diff --git a/docs/mapstore.md b/docs/mapstore.md index a47c670..ebaf965 100644 --- a/docs/mapstore.md +++ b/docs/mapstore.md @@ -1,12 +1,12 @@ -# MapStore - - +# MapStore +- [Introduction](#introduction) - [Implementations](#implementations) - * [SimpleMap](#simplemap) - * [BadgerV4](#badgerv4) + - [SimpleMap](#simplemap) + - [BadgerV4](#badgerv4) +- [Note On External Writability](#note-on-external-writability) - +## Introduction The `MapStore` is a simple interface used by the SM(S)T to store, delete and retrieve key-value pairs. It is intentionally simple and minimalistic so as to @@ -31,11 +31,15 @@ See [simplemap.go](../kvstore/simplemap/simplemap.go) for more details. ### BadgerV4 -This library provides a wrapper around [dgraph-io/badger][badgerv4] to adhere -to the `MapStore` interface. See the [full documentation](./badger-store.md) -for additional functionality and implementation details. +This library provides a wrapper around [dgraph-io/badger](https://github.com/dgraph-io/badger) to adhere to +the `MapStore` interface. See the [full documentation](./badger-store.md) for +additional functionality and implementation details. + +See: [badger](../kvstore/badger/) for more details on the implementation of this +submodule. -See: [badger](../kvstore/badger/) for more details on the implementation of -this submodule. +## Note On External Writability -[badgerv4]: https://github.com/dgraph-io/badger +Any key-value store used by the tries must **not** be externally writable in +production. This opens the door to attacks in which a writer modifies the trie +database directly and proves values that were never inserted.
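For readers of this patch, the `MapStore` contract the mapstore.md changes describe can be sketched as a minimal in-memory implementation. This is a hypothetical illustration in the spirit of `SimpleMap`, not the actual code in the kvstore package; the interface shape (`Get`/`Set`/`Delete` with error returns) is assumed from the description:

```go
package main

import (
	"errors"
	"fmt"
)

// MapStore is the minimal key-value contract the SM(S)T relies on:
// store, retrieve and delete key-value pairs. (Sketch only; the real
// interface lives in the kvstore package.)
type MapStore interface {
	Get(key []byte) ([]byte, error)
	Set(key, value []byte) error
	Delete(key []byte) error
}

var errKeyNotFound = errors.New("key not found")

// simpleMap is an in-memory MapStore backed by a Go map, illustrating
// why such a store is trivial to reason about but must not be
// externally writable in production.
type simpleMap struct {
	m map[string][]byte
}

func newSimpleMap() *simpleMap {
	return &simpleMap{m: make(map[string][]byte)}
}

// Get returns the value for key, or errKeyNotFound if it is absent.
func (s *simpleMap) Get(key []byte) ([]byte, error) {
	v, ok := s.m[string(key)]
	if !ok {
		return nil, errKeyNotFound
	}
	return v, nil
}

// Set stores value under key, overwriting any previous value.
func (s *simpleMap) Set(key, value []byte) error {
	s.m[string(key)] = value
	return nil
}

// Delete removes key, failing if it was never set.
func (s *simpleMap) Delete(key []byte) error {
	if _, ok := s.m[string(key)]; !ok {
		return errKeyNotFound
	}
	delete(s.m, string(key))
	return nil
}

func main() {
	var store MapStore = newSimpleMap()
	_ = store.Set([]byte("key"), []byte("value"))
	v, _ := store.Get([]byte("key"))
	fmt.Println(string(v)) // value
	_ = store.Delete([]byte("key"))
	_, err := store.Get([]byte("key"))
	fmt.Println(err) // key not found
}
```

Any store satisfying this contract can back the trie; the external-writability warning above applies to all of them equally, since the trie has no way to detect rows written behind its back.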
diff --git a/docs/smt.md b/docs/smt.md index 16220c5..7340ed6 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -1,23 +1,12 @@ -# smt - - - -- [smt](#smt) - - [Overview](#overview) - - [Implementation](#implementation) - - [Leaf Nodes](#leaf-nodes) - - [Inner Nodes](#inner-nodes) - - [Extension Nodes](#extension-nodes) - - [Lazy Nodes](#lazy-nodes) - - [Lazy Loading](#lazy-loading) - - [Visualisations](#visualisations) - - [General Trie Structure](#general-trie-structure) - - [Lazy Nodes](#lazy-nodes-1) +# smt + +- [Overview](#overview) - [Extension Nodes](#extension-nodes) - [Lazy Nodes](#lazy-nodes) - [Lazy Loading](#lazy-loading) - [Visualisations](#visualisations) - [General Trie Structure](#general-trie-structure) - [Lazy Nodes](#lazy-nodes-1) - [Paths](#paths) - [Visualisation](#visualisation) - [Values](#values) - [Nil values](#nil-values) - [Hashers \& Digests](#hashers--digests) + - [Hash Function Recommendations](#hash-function-recommendations) - [Roots](#roots) - [Proofs](#proofs) - [Verification](#verification) @@ -32,8 +21,6 @@ - [Data Loss](#data-loss) - [Sparse Merkle Sum Trie](#sparse-merkle-sum-trie) - - ## Overview Sparse Merkle Tries (SMTs) are efficient and secure data structures for storing @@ -77,10 +64,9 @@ constructor. ### Inner Nodes -Inner nodes represent a branch in the trie with two **non-nil** child nodes. - -The inner node has an internal `digest` which represents the hash of the child -nodes concatenated hashes. +Inner nodes represent a branch in the trie with two **non-nil** child nodes. The +inner node has an internal `digest` which represents the hash of the child nodes' +concatenated hashes. - _Prefix_: `[]byte{1}` - _Digest_: `hash([]byte{1} + leftChild.digest + rightChild.digest)` @@ -319,8 +305,8 @@ described in the [implementation](#implementation) section. The following diagram represents the creation of a leaf node in an abstracted and simplified manner.
-_Note: This diagram is not entirely accurate regarding the process of creating -a leaf node, but is a good representation of the process._ +_Note: This diagram is not entirely accurate regarding the process of creating a +leaf node, but is a good representation of the process._ ```mermaid graph TD VH --ValueHash-->L ``` +### Hash Function Recommendations + +Although any hash function that satisfies the `hash.Hash` interface can be used +to construct the trie, it is **strongly recommended** to use a hashing function +that provides the following properties: + +- **Collision resistance**: The hash function must be collision resistant. This + is needed so that distinct inputs to the SMT produce distinct digests. +- **Preimage resistance**: The hash function must be preimage resistant. This + is needed to protect against Merkle tree construction attacks, where an + attacker who could invert digests could forge node data they never saw. +- **Efficiency**: The hash function must be efficient, as it is used to compute + the hash of many nodes in the trie. + ## Roots The root of the tree is a slice of bytes. `MerkleRoot` is an alias for `[]byte`. -This design enables easily passing around the data (e.g. on-chain) -while maintaining primitive usage in different use cases (e.g. proofs). +This design enables easily passing around the data (e.g. on-chain) while +maintaining primitive usage in different use cases (e.g. proofs). `MerkleRoot` provides helpers, such as retrieving the `Sum() uint64` to -interface with data it captures. However, for the SMT it **always** panics, -as there is no sum. +interface with data it captures. However, for the SMT it **always** panics, as +there is no sum.
## Proofs -The `SparseMerkleProof` type contains the information required for inclusion -and exclusion proofs, depending on the key provided to the trie method +The `SparseMerkleProof` type contains the information required for inclusion and +exclusion proofs; depending on the key provided to the trie method `Prove(key []byte)` either an inclusion or exclusion proof will be generated. _NOTE: The inclusion and exclusion proof are the same type, just constructed @@ -409,29 +409,29 @@ using the `VerifyClosestProof` function which requires the proof and root hash of the trie. Since the `ClosestProof` method takes a hash as input, it is possible to place a -leaf in the trie according to the hash's path, if it is known. Depending on -the use case of this function this may expose a vulnerability. **It is not -intendend to be used as a general purpose proof mechanism**, but instead as a -**Commit and Reveal** mechanism, as detailed below. +leaf in the trie according to the hash's path, if it is known. Depending on the +use case of this function, this may expose a vulnerability. **It is not intended +to be used as a general purpose proof mechanism**, but instead as a **Commit and +Reveal** mechanism, as detailed below. #### Closest Proof Use Cases The `ClosestProof` function is intended for use as a `commit & reveal` mechanism. It involves two actors: the **prover** and the **verifier**. -_NOTE: Throughout this document, `commitment` of the the trie's root hash is also -referred to as closing the trie, such that no more updates are made to it once -committed._ +_NOTE: Throughout this document, `commitment` of the trie's root hash is +also referred to as closing the trie, such that no more updates are made to it +once committed._ Take the following attack vector (**without** a commit prior to a reveal) into consideration: 1. The **verifier** picks the hash (i.e. a single branch) they intend to check -1. The **prover** inserts a leaf (i.e.
a value) whose key (determined via the +2. The **prover** inserts a leaf (i.e. a value) whose key (determined via the hasher) has a longer common prefix than any other leaf in the trie. -1. Due to the deterministic nature of the `ClosestProof`, method this leaf will +3. Due to the deterministic nature of the `ClosestProof` method, this leaf will **always** be returned given the identified hash. -1. The **verifier** then verifies the revealed `ClosestProof`, which returns a +4. The **verifier** then verifies the revealed `ClosestProof`, which returns a branch the **prover** inserted after knowing which leaf was going to be checked. @@ -440,16 +440,16 @@ Consider the following normal flow (**with** a commit prior to reveal) as 1. The **prover** commits to the state of their trie by publishing their root hash, thereby _closing_ their trie and not being able to make further changes. -1. The **verifier** selects a hash to be used in the `commit & reveal` process +2. The **verifier** selects a hash to be used in the `commit & reveal` process that the **prover** must provide a closest proof for. -1. The **prover** utilises this hash and computes the `ClosestProof` on their +3. The **prover** utilises this hash and computes the `ClosestProof` on their _closed_ trie, producing a `ClosestProof`, thus revealing a deterministic, pseudo-random leaf that existed in the trie prior to commitment. -1. The **verifier** verifies the proof, in turn, verifying the commitment - made by the **prover** to the state of the trie in the first step. -1. The **prover** had no opportunity to insert a new leaf into the trie - after learning which hash the **verifier** was going to require a - `ClosestProof` for. +4. The **verifier** verifies the proof, in turn, verifying the commitment made + by the **prover** to the state of the trie in the first step. +5. 
The **prover** had no opportunity to insert a new leaf into the trie after + learning which hash the **verifier** was going to require a `ClosestProof` + for. ### Compression @@ -509,7 +509,8 @@ database. It's interface exposes numerous extra methods not used by the trie, However it can still be used as a node-store with both in-memory and persistent options. -See [badger-store.md](./badger-store.md.md) for the details of the implementation. +See [badger-store.md](./badger-store.md) for the details of the +implementation. ### Data Loss diff --git a/errors.go b/errors.go index b8d6bbe..8b89097 100644 --- a/errors.go +++ b/errors.go @@ -9,4 +9,7 @@ var ( ErrBadProof = errors.New("bad proof") // ErrKeyNotFound is returned when a key is not found in the tree. ErrKeyNotFound = errors.New("key not found") + // ErrInvalidClosestPath is returned when the path used in the ClosestProof + // method does not match the size of the trie's PathHasher. + ErrInvalidClosestPath = errors.New("invalid path does not match path hasher size") ) diff --git a/hasher.go b/hasher.go index 570a03e..04e8e11 100644 --- a/hasher.go +++ b/hasher.go @@ -27,6 +27,8 @@ type PathHasher interface { type ValueHasher interface { // HashValue hashes value data to produce the digest stored in leaf node. HashValue([]byte) []byte + // ValueHashSize returns the length (in bytes) of digests produced by this hasher. + ValueHashSize() int } // trieHasher is a common hasher for all trie hashers (paths & values).
@@ -56,8 +58,8 @@ func NewTrieHasher(hasher hash.Hash) *trieHasher { return &th } -func NewNilPathHasher(hasher hash.Hash) PathHasher { - return &nilPathHasher{hashSize: hasher.Size()} +func NewNilPathHasher(hasherSize int) PathHasher { + return &nilPathHasher{hashSize: hasherSize} } // Path returns the digest of a key produced by the path hasher @@ -76,6 +78,14 @@ func (vh *valueHasher) HashValue(data []byte) []byte { return vh.digestData(data) } +// ValueHashSize returns the length (in bytes) of digests produced by the value hasher +func (vh *valueHasher) ValueHashSize() int { + if vh.hasher == nil { + return 0 + } + return vh.hasher.Size() +} + // Path satisfies the PathHasher#Path interface func (n *nilPathHasher) Path(key []byte) []byte { return key[:n.hashSize] diff --git a/options.go b/options.go index 6164817..4556a4c 100644 --- a/options.go +++ b/options.go @@ -27,7 +27,7 @@ func NoHasherSpec(hasher hash.Hash, sumTrie bool) *TrieSpec { spec := newTrieSpec(hasher, sumTrie) // Set a nil path hasher - opt := WithPathHasher(NewNilPathHasher(hasher)) + opt := WithPathHasher(NewNilPathHasher(hasher.Size())) opt(&spec) // Set a nil value hasher diff --git a/proofs.go b/proofs.go index a7eed6e..3b43dbc 100644 --- a/proofs.go +++ b/proofs.go @@ -58,6 +58,16 @@ func (proof *SparseMerkleProof) validateBasic(spec *TrieSpec) error { return fmt.Errorf("too many side nodes: got %d but max is %d", len(proof.SideNodes), spec.ph.PathSize()*8) } + // Check that leaf data for non-membership proofs is a valid size. + lps := len(leafNodePrefix) + spec.ph.PathSize() + if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < lps { + return fmt.Errorf( + "invalid non-membership leaf data size: got %d but min is %d", + len(proof.NonMembershipLeafData), + lps, + ) + } + // Verify that the non-membership leaf data is of the correct size. 
leafPathSize := len(leafNodePrefix) + spec.ph.PathSize() if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < leafPathSize { @@ -134,7 +144,11 @@ func (proof *SparseCompactMerkleProof) validateBasic(spec *TrieSpec) error { // Compact proofs: check that NumSideNodes is within the right range. if proof.NumSideNodes < 0 || proof.NumSideNodes > spec.ph.PathSize()*8 { - return fmt.Errorf("invalid number of side nodes: got %d, min is 0 and max is %d", len(proof.SideNodes), spec.ph.PathSize()*8) + return fmt.Errorf( + "invalid number of side nodes: got %d, min is 0 and max is %d", + proof.NumSideNodes, + spec.ph.PathSize()*8, + ) } // Compact proofs: check that the length of the bit mask is as expected @@ -186,7 +200,24 @@ func (proof *SparseMerkleClosestProof) Unmarshal(bz []byte) error { return dec.Decode(proof) } +// GetValueHash returns the value hash of the closest proof. +func (proof *SparseMerkleClosestProof) GetValueHash(spec *TrieSpec) []byte { + if proof.ClosestValueHash == nil { + return nil + } + if spec.sumTrie { + return proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSizeBits] + } + return proof.ClosestValueHash +} + func (proof *SparseMerkleClosestProof) validateBasic(spec *TrieSpec) error { + // Ensure the proof's path is the same length (in bytes) as the digests + // produced by the spec's path hasher. + if len(proof.Path) != spec.ph.PathSize() { + return fmt.Errorf("invalid path length: got %d, want %d", len(proof.Path), spec.ph.PathSize()) + } + // Ensure the depth of the leaf node being proven is within the path size if proof.Depth < 0 || proof.Depth > spec.ph.PathSize()*8 { return fmt.Errorf("invalid depth: got %d, outside of [0, %d]", proof.Depth, spec.ph.PathSize()*8) @@ -232,6 +263,12 @@ type SparseCompactMerkleClosestProof struct { } func (proof *SparseCompactMerkleClosestProof) validateBasic(spec *TrieSpec) error { + // Ensure the proof's path is the same length (in bytes) as the digests + // produced by the spec's path hasher. + if 
len(proof.Path) != spec.ph.PathSize() { + return fmt.Errorf("invalid path length: got %d, want %d", len(proof.Path), spec.ph.PathSize()) + } + // Do a basic sanity check on the fields of the proof specific to // the compact proof only. // @@ -247,7 +284,12 @@ func (proof *SparseCompactMerkleClosestProof) validateBasic(spec *TrieSpec) erro } for i, b := range proof.FlippedBits { if len(b) > maxSliceLen { - return fmt.Errorf("invalid compressed flipped bit index %d: got length %d, max is %d]", i, bytesToInt(b), maxSliceLen) + return fmt.Errorf( + "invalid compressed flipped bit index %d: got length %d, max is %d", + i, + len(b), + maxSliceLen, + ) } } // perform a sanity check on the closest proof @@ -302,23 +344,30 @@ func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint6 // VerifyClosestProof verifies a Merkle proof for a proof of inclusion for a leaf // found to have the closest path to the one provided to the proof structure -// -// TO_AUDITOR: This is akin to an inclusion proof with N (num flipped bits) exclusion -// proof wrapped into one and needs to be reviewed from an algorithm POV. func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *TrieSpec) (bool, error) { if err := proof.validateBasic(spec); err != nil { return false, errors.Join(ErrBadProof, err) } - if !spec.sumTrie { - return VerifyProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash, spec) + // Create a new TrieSpec with a nil path hasher. + // Since the ClosestProof already contains a hashed path, double hashing it + // will invalidate the proof.
+ nilSpec := &TrieSpec{ + th: spec.th, + ph: NewNilPathHasher(spec.ph.PathSize()), + vh: spec.vh, + sumTrie: spec.sumTrie, + } + if !nilSpec.sumTrie { + return VerifyProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash, nilSpec) } if proof.ClosestValueHash == nil { - return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, spec) + return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, nilSpec) } sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSizeBits:] sum := binary.BigEndian.Uint64(sumBz) + valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSizeBits] - return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, spec) + return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, nilSpec) } // verifyProofWithUpdates @@ -392,7 +441,13 @@ func VerifyCompactProof(proof *SparseCompactMerkleProof, root []byte, key, value } // VerifyCompactSumProof is similar to VerifySumProof but for a compacted Merkle proof. 
-func VerifyCompactSumProof(proof *SparseCompactMerkleProof, root []byte, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) { +func VerifyCompactSumProof( + proof *SparseCompactMerkleProof, + root []byte, + key, value []byte, + sum uint64, + spec *TrieSpec, +) (bool, error) { decompactedProof, err := DecompactProof(proof, spec) if err != nil { return false, errors.Join(ErrBadProof, err) diff --git a/proofs_test.go b/proofs_test.go index de4d3e9..dddd41f 100644 --- a/proofs_test.go +++ b/proofs_test.go @@ -28,7 +28,7 @@ func TestSparseMerkleProof_Marshal(t *testing.T) { require.Greater(t, len(bz2), 0) require.NotEqual(t, bz, bz2) - proof3 := randomiseProof(proof) + proof3 := randomizeProof(proof) bz3, err := proof3.Marshal() require.NoError(t, err) require.NotNil(t, bz3) @@ -59,7 +59,7 @@ func TestSparseMerkleProof_Unmarshal(t *testing.T) { require.NoError(t, uproof2.Unmarshal(bz2)) require.Equal(t, proof2, uproof2) - proof3 := randomiseProof(proof) + proof3 := randomizeProof(proof) bz3, err := proof3.Marshal() require.NoError(t, err) require.NotNil(t, bz3) @@ -91,7 +91,7 @@ func TestSparseCompactMerkleProof_Marshal(t *testing.T) { require.Greater(t, len(bz2), 0) require.NotEqual(t, bz, bz2) - proof3 := randomiseProof(proof) + proof3 := randomizeProof(proof) compactProof3, err := CompactProof(proof3, trie.Spec()) require.NoError(t, err) bz3, err := compactProof3.Marshal() @@ -134,7 +134,7 @@ func TestSparseCompactMerkleProof_Unmarshal(t *testing.T) { require.NoError(t, err) require.Equal(t, proof2, uproof2) - proof3 := randomiseProof(proof) + proof3 := randomizeProof(proof) compactProof3, err := CompactProof(proof3, trie.Spec()) require.NoError(t, err) bz3, err := compactProof3.Marshal() @@ -162,7 +162,7 @@ func setupTrie(t *testing.T) *SMT { return trie } -func randomiseProof(proof *SparseMerkleProof) *SparseMerkleProof { +func randomizeProof(proof *SparseMerkleProof) *SparseMerkleProof { sideNodes := make([][]byte, len(proof.SideNodes)) for i := range 
sideNodes { sideNodes[i] = make([]byte, len(proof.SideNodes[i])) @@ -174,7 +174,7 @@ func randomiseProof(proof *SparseMerkleProof) *SparseMerkleProof { } } -func randomiseSumProof(proof *SparseMerkleProof) *SparseMerkleProof { +func randomizeSumProof(proof *SparseMerkleProof) *SparseMerkleProof { sideNodes := make([][]byte, len(proof.SideNodes)) for i := range sideNodes { sideNodes[i] = make([]byte, len(proof.SideNodes[i])-sumSizeBits) diff --git a/smst_proofs_test.go b/smst_proofs_test.go index d2b06ac..cf5e41f 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -79,7 +79,14 @@ func TestSMST_Proof_Operations(t *testing.T) { result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 10, base) // wrong value and sum require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey"), []byte("testValue"), 5, base) // invalid proof + result, err = VerifySumProof( + randomizeSumProof(proof), + root, + []byte("testKey"), + []byte("testValue"), + 5, + base, + ) // invalid proof require.NoError(t, err) require.False(t, result) @@ -98,7 +105,14 @@ func TestSMST_Proof_Operations(t *testing.T) { result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("badValue"), 10, base) // wrong value and sum require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey2"), []byte("testValue"), 5, base) // invalid proof + result, err = VerifySumProof( + randomizeSumProof(proof), + root, + []byte("testKey2"), + []byte("testValue"), + 5, + base, + ) // invalid proof require.NoError(t, err) require.False(t, result) @@ -129,7 +143,14 @@ func TestSMST_Proof_Operations(t *testing.T) { result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultEmptyValue, 5, base) // wrong sum require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey3"), 
defaultEmptyValue, 0, base) // invalid proof + result, err = VerifySumProof( + randomizeSumProof(proof), + root, + []byte("testKey3"), + defaultEmptyValue, + 0, + base, + ) // invalid proof require.NoError(t, err) require.False(t, result) } @@ -204,7 +225,6 @@ func TestSMST_Proof_ValidateBasic(t *testing.T) { func TestSMST_ClosestProof_ValidateBasic(t *testing.T) { smn := simplemap.NewSimpleMap() smst := NewSparseMerkleSumTrie(smn, sha256.New()) - np := NoHasherSpec(sha256.New(), true) base := smst.Spec() path := sha256.Sum256([]byte("testKey2")) flipPathBit(path[:], 3) @@ -227,14 +247,14 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) { require.NoError(t, err) proof.Depth = -1 require.EqualError(t, proof.validateBasic(base), "invalid depth: got -1, outside of [0, 256]") - result, err := VerifyClosestProof(proof, root, np) + result, err := VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) require.Error(t, err) proof.Depth = 257 require.EqualError(t, proof.validateBasic(base), "invalid depth: got 257, outside of [0, 256]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -244,14 +264,14 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) { require.NoError(t, err) proof.FlippedBits[0] = -1 require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got -1, outside of [0, 8]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) require.Error(t, err) proof.FlippedBits[0] = 9 require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got 9, outside of [0, 8]") - result, err = 
VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -265,7 +285,7 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) { proof.validateBasic(base), "invalid closest path: 8d13809f932d0296b88c1913231ab4b403f05c88363575476204fef6930f22ae (not equal at bit: 3)", ) - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -326,7 +346,7 @@ func TestSMST_ProveClosest(t *testing.T) { ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields }) - result, err = VerifyClosestProof(proof, root, NoHasherSpec(sha256.New(), true)) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.NoError(t, err) require.True(t, result) @@ -352,7 +372,7 @@ func TestSMST_ProveClosest(t *testing.T) { ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields }) - result, err = VerifyClosestProof(proof, root, NoHasherSpec(sha256.New(), true)) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.NoError(t, err) require.True(t, result) } @@ -381,7 +401,7 @@ func TestSMST_ProveClosest_Empty(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smst.Root(), NoHasherSpec(sha256.New(), true)) + result, err := VerifyClosestProof(proof, smst.Root(), smst.Spec()) require.NoError(t, err) require.True(t, result) } @@ -419,7 +439,7 @@ func TestSMST_ProveClosest_OneNode(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smst.Root(), NoHasherSpec(sha256.New(), true)) + result, err := VerifyClosestProof(proof, smst.Root(), smst.Spec()) require.NoError(t, err) require.True(t, result) } diff --git a/smt.go 
b/smt.go index fd1fac4..897707c 100644 --- a/smt.go +++ b/smt.go @@ -156,6 +156,19 @@ func (smt *SMT) update( return newLeaf, nil } // We insert an "extension" representing multiple single-branch inner nodes + var newInner *innerNode + if getPathBit(path, prefixLen) == leftChildBit { + newInner = &innerNode{ + leftChild: newLeaf, + rightChild: leaf, + } + } else { + newInner = &innerNode{ + leftChild: leaf, + rightChild: newLeaf, + } + } + // Determine if we need to insert an extension or a branch last := &node if depth < prefixLen { // note: this keeps path slice alive - GC inefficiency? @@ -163,25 +176,17 @@ func (smt *SMT) update( panic("invalid depth") } ext := extensionNode{ - path: path, + child: newInner, + path: path, pathBounds: [2]byte{ - byte(depth), - byte(prefixLen), + byte(depth), byte(prefixLen), }, } + // Dereference the last node to replace it with the extension node *last = &ext - last = &ext.child - } - if getPathBit(path, prefixLen) == leftChildBit { - *last = &innerNode{ - leftChild: newLeaf, - rightChild: leaf, - } } else { - *last = &innerNode{ - leftChild: leaf, - rightChild: newLeaf, - } + // Dereference the last node to replace it with the new inner node + *last = newInner } return node, nil } @@ -393,11 +398,16 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) { // node is encountered, the traversal backsteps and flips the path bit for that // depth (ie tries left if it tried right and vice versa). This guarantees that // a proof of inclusion is found that has the most common bits with the path -// provided, biased to the longest common prefix +// provided, biased to the longest common prefix. func (smt *SMT) ProveClosest(path []byte) ( proof *SparseMerkleClosestProof, // proof of the key-value pair found err error, // the error value encountered ) { + // Ensure the path provided is the correct length for the path hasher. 
+ if len(path) != smt.Spec().ph.PathSize() { + return nil, ErrInvalidClosestPath + } + workingPath := make([]byte, len(path)) copy(workingPath, path) var siblings []trieNode diff --git a/smt_proofs_test.go b/smt_proofs_test.go index 61e3f80..cf2bf89 100644 --- a/smt_proofs_test.go +++ b/smt_proofs_test.go @@ -66,7 +66,7 @@ func TestSMT_Proof_Operations(t *testing.T) { result, err = VerifyProof(proof, root, []byte("testKey"), []byte("badValue"), base) require.NoError(t, err) require.False(t, result) - result, err = VerifyProof(randomiseProof(proof), root, []byte("testKey"), []byte("testValue"), base) + result, err = VerifyProof(randomizeProof(proof), root, []byte("testKey"), []byte("testValue"), base) require.NoError(t, err) require.False(t, result) @@ -79,7 +79,7 @@ func TestSMT_Proof_Operations(t *testing.T) { result, err = VerifyProof(proof, root, []byte("testKey2"), []byte("badValue"), base) require.NoError(t, err) require.False(t, result) - result, err = VerifyProof(randomiseProof(proof), root, []byte("testKey2"), []byte("testValue"), base) + result, err = VerifyProof(randomizeProof(proof), root, []byte("testKey2"), []byte("testValue"), base) require.NoError(t, err) require.False(t, result) @@ -103,7 +103,7 @@ func TestSMT_Proof_Operations(t *testing.T) { result, err = VerifyProof(proof, root, []byte("testKey3"), []byte("badValue"), base) require.NoError(t, err) require.False(t, result) - result, err = VerifyProof(randomiseProof(proof), root, []byte("testKey3"), defaultEmptyValue, base) + result, err = VerifyProof(randomizeProof(proof), root, []byte("testKey3"), defaultEmptyValue, base) require.NoError(t, err) require.False(t, result) } @@ -178,7 +178,6 @@ func TestSMT_Proof_ValidateBasic(t *testing.T) { func TestSMT_ClosestProof_ValidateBasic(t *testing.T) { smn := simplemap.NewSimpleMap() smt := NewSparseMerkleTrie(smn, sha256.New()) - np := NoHasherSpec(sha256.New(), false) base := smt.Spec() path := sha256.Sum256([]byte("testKey2")) flipPathBit(path[:], 3) 
@@ -201,14 +200,14 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) { require.NoError(t, err) proof.Depth = -1 require.EqualError(t, proof.validateBasic(base), "invalid depth: got -1, outside of [0, 256]") - result, err := VerifyClosestProof(proof, root, np) + result, err := VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) require.Error(t, err) proof.Depth = 257 require.EqualError(t, proof.validateBasic(base), "invalid depth: got 257, outside of [0, 256]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -218,14 +217,14 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) { require.NoError(t, err) proof.FlippedBits[0] = -1 require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got -1, outside of [0, 8]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) require.Error(t, err) proof.FlippedBits[0] = 9 require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got 9, outside of [0, 8]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -239,7 +238,7 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) { proof.validateBasic(base), "invalid closest path: 8d13809f932d0296b88c1913231ab4b403f05c88363575476204fef6930f22ae (not equal at bit: 3)", ) - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, 
err = CompactClosestProof(proof, base) @@ -287,7 +286,7 @@ func TestSMT_ProveClosest(t *testing.T) { checkClosestCompactEquivalence(t, proof, smt.Spec()) require.NotEqual(t, proof, &SparseMerkleClosestProof{}) - result, err = VerifyClosestProof(proof, root, NoHasherSpec(sha256.New(), false)) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.NoError(t, err) require.True(t, result) closestPath := sha256.Sum256([]byte("testKey2")) @@ -304,7 +303,7 @@ func TestSMT_ProveClosest(t *testing.T) { checkClosestCompactEquivalence(t, proof, smt.Spec()) require.NotEqual(t, proof, &SparseMerkleClosestProof{}) - result, err = VerifyClosestProof(proof, root, NoHasherSpec(sha256.New(), false)) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.NoError(t, err) require.True(t, result) closestPath = sha256.Sum256([]byte("testKey4")) @@ -336,7 +335,7 @@ func TestSMT_ProveClosest_Empty(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smt.Root(), NoHasherSpec(sha256.New(), false)) + result, err := VerifyClosestProof(proof, smt.Root(), smt.Spec()) require.NoError(t, err) require.True(t, result) } @@ -368,7 +367,7 @@ func TestSMT_ProveClosest_OneNode(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smt.Root(), NoHasherSpec(sha256.New(), false)) + result, err := VerifyClosestProof(proof, smt.Root(), smt.Spec()) require.NoError(t, err) require.True(t, result) }