From d329ba550c62f95aa51f4deb183b82e4d9eff0b7 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Fri, 12 Jul 2024 12:44:38 +0200 Subject: [PATCH 01/10] refactor: SMST#Root(), #Sum(), & #Count() --- root.go | 47 ++++++++++++++++++++++++++++++++++------------- smst.go | 48 ++++++++++++++++++++++++++++-------------------- types.go | 13 +++++++++---- 3 files changed, 71 insertions(+), 37 deletions(-) diff --git a/root.go b/root.go index 73dc4af..0613fb8 100644 --- a/root.go +++ b/root.go @@ -15,7 +15,7 @@ const ( // MustSum returns the uint64 sum of the merkle root, it checks the length of the // merkle root and if it is no the same as the size of the SMST's expected // root hash it will panic. -func (r MerkleRoot) MustSum() uint64 { +func (r MerkleSumRoot) MustSum() uint64 { sum, err := r.Sum() if err != nil { panic(err) @@ -27,28 +27,49 @@ func (r MerkleRoot) MustSum() uint64 { // Sum returns the uint64 sum of the merkle root, it checks the length of the // merkle root and if it is no the same as the size of the SMST's expected // root hash it will return an error. -func (r MerkleRoot) Sum() (uint64, error) { - if len(r)%SmtRootSizeBytes == 0 { - return 0, fmt.Errorf("root#sum: not a merkle sum trie") +func (r MerkleSumRoot) Sum() (uint64, error) { + if len(r) != SmstRootSizeBytes { + return 0, fmt.Errorf("MerkleSumRoot#Sum: not a merkle sum trie") } - firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx([]byte(r)) + return getSum(r), nil +} - var sumBz [sumSizeBytes]byte - copy(sumBz[:], []byte(r)[firstSumByteIdx:firstCountByteIdx]) - return binary.BigEndian.Uint64(sumBz[:]), nil +// MustCount returns the uint64 count of the merkle root, a cryptographically secure +// count of the number of non-empty leafs in the tree. +func (r MerkleSumRoot) MustCount() uint64 { + count, err := r.Count() + if err != nil { + panic(err) + } + + return count } // Count returns the uint64 count of the merkle root, a cryptographically secure // count of the number of non-empty leafs in the tree. -func (r MerkleRoot) Count() uint64 { - if len(r)%SmtRootSizeBytes == 0 { - panic("root#sum: not a merkle sum trie") +func (r MerkleSumRoot) Count() (uint64, error) { + if len(r) != SmstRootSizeBytes { + return 0, fmt.Errorf("MerkleSumRoot#Count: not a merkle sum trie") } - _, firstCountByteIdx := getFirstMetaByteIdx([]byte(r)) + return getCount(r), nil +} + +// getSum returns the sum of the node stored in the root. +func getSum(root []byte) uint64 { + firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(root) + + var sumBz [sumSizeBytes]byte + copy(sumBz[:], root[firstSumByteIdx:firstCountByteIdx]) + return binary.BigEndian.Uint64(sumBz[:]) +} + +// getCount returns the count of the node stored in the root. +func getCount(root []byte) uint64 { + _, firstCountByteIdx := getFirstMetaByteIdx(root) var countBz [countSizeBytes]byte - copy(countBz[:], []byte(r)[firstCountByteIdx:]) + copy(countBz[:], root[firstCountByteIdx:]) return binary.BigEndian.Uint64(countBz[:]) } diff --git a/smst.go b/smst.go index b07114b..09b5128 100644 --- a/smst.go +++ b/smst.go @@ -3,6 +3,7 @@ package smt import ( "bytes" "encoding/binary" + "fmt" "hash" "github.com/pokt-network/smt/kvstore" @@ -170,39 +171,46 @@ func (smst *SMST) Commit() error { } // Root returns the root hash of the trie with the total sum bytes appended -func (smst *SMST) Root() MerkleRoot { - return smst.SMT.Root() // [digest]+[binary sum] +func (smst *SMST) Root() MerkleSumRoot { + return MerkleSumRoot(smst.SMT.Root()) // [digest]+[binary sum]+[binary count] } -// Sum returns the sum of the entire trie stored in the root. +// MustSum returns the sum of the entire trie stored in the root. // If the tree is not a sum tree, it will panic. -func (smst *SMST) Sum() uint64 { - rootDigest := []byte(smst.Root()) +func (smst *SMST) MustSum() uint64 { + sum, err := smst.Sum() + if err != nil { + panic(err) + } + return sum +} +// Sum returns the sum of the entire trie stored in the root. +// If the tree is not a sum tree, it will panic. +func (smst *SMST) Sum() (uint64, error) { if !smst.Spec().sumTrie { - panic("SMST: not a merkle sum trie") + return 0, fmt.Errorf("SMST: not a merkle sum trie") } - firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(rootDigest) + return smst.Root().Sum() +} - var sumBz [sumSizeBytes]byte - copy(sumBz[:], rootDigest[firstSumByteIdx:firstCountByteIdx]) - return binary.BigEndian.Uint64(sumBz[:]) +// MustCount returns the number of non-empty nodes in the entire trie stored in the root. +func (smst *SMST) MustCount() uint64 { + count, err := smst.Count() + if err != nil { + panic(err) + } + return count } // Count returns the number of non-empty nodes in the entire trie stored in the root. -func (smst *SMST) Count() uint64 { - rootDigest := []byte(smst.Root()) - +func (smst *SMST) Count() (uint64, error) { if !smst.Spec().sumTrie { - panic("SMST: not a merkle sum trie") + return 0, fmt.Errorf("SMST: not a merkle sum trie") } - _, firstCountByteIdx := getFirstMetaByteIdx(rootDigest) - - var countBz [countSizeBytes]byte - copy(countBz[:], rootDigest[firstCountByteIdx:]) - return binary.BigEndian.Uint64(countBz[:]) + return smst.Root().Count() } // getFirstMetaByteIdx returns the index of the first count byte and the first sum byte @@ -211,5 +219,5 @@ func (smst *SMST) Count() uint64 { func getFirstMetaByteIdx(data []byte) (firstSumByteIdx, firstCountByteIdx int) { firstCountByteIdx = len(data) - countSizeBytes firstSumByteIdx = firstCountByteIdx - sumSizeBytes - return + return firstSumByteIdx, firstCountByteIdx } diff --git a/types.go b/types.go index 0bed0a0..fa2ca4c 100644 --- a/types.go +++ b/types.go @@ -22,9 +22,12 @@ var ( defaultEmptyCount [countSizeBytes]byte ) -// MerkleRoot is a type alias for a byte slice returned from the Root method +// MerkleRoot is a type alias for a byte slice returned from SparseMerkleTrie#Root(). type MerkleRoot []byte +// MerkleSumRoot is a type alias for a byte slice returned from SparseMerkleSumTrie#Root(). +type MerkleSumRoot []byte + // A high-level interface that captures the behaviour of all types of nodes type trieNode interface { // Persisted returns a boolean to determine whether or not the node @@ -68,11 +71,13 @@ type SparseMerkleSumTrie interface { // Get descends the trie to access a value. Returns nil if key is not present. Get(key []byte) (data []byte, sum uint64, err error) // Root computes the Merkle root digest. - Root() MerkleRoot + Root() MerkleSumRoot // Sum computes the total sum of the Merkle trie - Sum() uint64 + Sum() (uint64, error) + MustSum() uint64 // Count returns the total number of non-empty leaves in the trie - Count() uint64 + Count() (uint64, error) + MustCount() uint64 // Prove computes a Merkle proof of inclusion or exclusion of a key. Prove(key []byte) (*SparseMerkleProof, error) // ProveClosest computes a Merkle proof of inclusion for a key in the trie From e3e6ed5a282ad3cc3b046713f426fcc6ea57c5e7 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Fri, 12 Jul 2024 13:03:16 +0200 Subject: [PATCH 02/10] chore: self Signed-off-by: Bryan White --- root.go | 6 ++++-- smst.go | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/root.go b/root.go index 0613fb8..faf315c 100644 --- a/root.go +++ b/root.go @@ -36,7 +36,8 @@ func (r MerkleSumRoot) Sum() (uint64, error) { } // MustCount returns the uint64 count of the merkle root, a cryptographically secure -// count of the number of non-empty leafs in the tree. +// count of the number of non-empty leafs in the tree. It panics if the root hash length +// does not match that of the SMST hasher. func (r MerkleSumRoot) MustCount() uint64 { count, err := r.Count() if err != nil { @@ -47,7 +48,8 @@ func (r MerkleSumRoot) MustCount() uint64 { } // Count returns the uint64 count of the merkle root, a cryptographically secure -// count of the number of non-empty leafs in the tree. +// count of the number of non-empty leafs in the tree. It returns an error if the root hash length +// does not match that of the SMST hasher. func (r MerkleSumRoot) Count() (uint64, error) { if len(r) != SmstRootSizeBytes { return 0, fmt.Errorf("MerkleSumRoot#Count: not a merkle sum trie") diff --git a/smst.go b/smst.go index 09b5128..5559c32 100644 --- a/smst.go +++ b/smst.go @@ -186,7 +186,7 @@ func (smst *SMST) MustSum() uint64 { } // Sum returns the sum of the entire trie stored in the root. -// If the tree is not a sum tree, it will panic. +// If the tree is not a sum tree, it will return an error. func (smst *SMST) Sum() (uint64, error) { if !smst.Spec().sumTrie { return 0, fmt.Errorf("SMST: not a merkle sum trie") @@ -196,6 +196,7 @@ func (smst *SMST) Sum() (uint64, error) { } // MustCount returns the number of non-empty nodes in the entire trie stored in the root. +// If the tree is not a sum tree, it will panic. func (smst *SMST) MustCount() uint64 { count, err := smst.Count() if err != nil { @@ -205,6 +206,7 @@ func (smst *SMST) MustCount() uint64 { } // Count returns the number of non-empty nodes in the entire trie stored in the root. +// If the tree is not a sum tree, it will return an error. func (smst *SMST) Count() (uint64, error) { if !smst.Spec().sumTrie { return 0, fmt.Errorf("SMST: not a merkle sum trie") From 9ab524a5626ea7fde608976d1b2d81b608879c50 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Mon, 15 Jul 2024 13:27:47 +0200 Subject: [PATCH 03/10] fix: tests --- root_test.go | 6 +++--- smst_example_test.go | 6 +++--- smst_test.go | 40 ++++++++++++++++++++-------------------- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/root_test.go b/root_test.go index fb89236..b417ff5 100644 --- a/root_test.go +++ b/root_test.go @@ -58,9 +58,9 @@ func TestMerkleRoot_TrieTypes(t *testing.T) { for i := uint64(0); i < 10; i++ { require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i)) } - require.NotNil(t, trie.Sum()) - require.EqualValues(t, 45, trie.Sum()) - require.EqualValues(t, 10, trie.Count()) + require.NotNil(t, trie.MustSum()) + require.EqualValues(t, 45, trie.MustSum()) + require.EqualValues(t, 10, trie.MustCount()) return } diff --git a/smst_example_test.go b/smst_example_test.go index c5d5ac6..018657d 100644 --- a/smst_example_test.go +++ b/smst_example_test.go @@ -28,7 +28,7 @@ func TestExampleSMST(t *testing.T) { _ = trie.Commit() // Calculate the total sum of the trie - _ = trie.Sum() // 20 + _ = trie.MustSum() // 20 // Generate a Merkle proof for "foo" proof1, _ := trie.Prove([]byte("foo")) @@ -52,8 +52,8 @@ func TestExampleSMST(t *testing.T) { require.False(t, valid_false1) // Verify the total sum of the trie - require.EqualValues(t, 20, trie.Sum()) + require.EqualValues(t, 20, trie.MustSum()) // Verify the number of non-empty leafs in the trie - require.EqualValues(t, 3, trie.Count()) + require.EqualValues(t, 3, trie.MustCount()) } diff --git a/smst_test.go b/smst_test.go index bd52cd4..5875678 100644 --- a/smst_test.go +++ b/smst_test.go @@ -359,7 +359,7 @@ func TestSMST_OrphanRemoval(t *testing.T) { err = smst.Update([]byte("testKey"), []byte("testValue"), 5) require.NoError(t, err) require.Equal(t, 1, nodeCount(t)) // only root node - require.Equal(t, uint64(1), impl.Count()) + require.Equal(t, uint64(1), impl.MustCount()) } t.Run("delete 1", func(t *testing.T) { @@ -367,7 +367,7 @@ func TestSMST_OrphanRemoval(t *testing.T) { err = smst.Delete([]byte("testKey")) require.NoError(t, err) require.Equal(t, 0, nodeCount(t)) - require.Equal(t, uint64(0), impl.Count()) + require.Equal(t, uint64(0), impl.MustCount()) }) t.Run("overwrite 1", func(t *testing.T) { @@ -375,7 +375,7 @@ func TestSMST_OrphanRemoval(t *testing.T) { err = smst.Update([]byte("testKey"), []byte("testValue2"), 10) require.NoError(t, err) require.Equal(t, 1, nodeCount(t)) - require.Equal(t, uint64(1), impl.Count()) + require.Equal(t, uint64(1), impl.MustCount()) }) t.Run("overwrite and delete", func(t *testing.T) { @@ -383,12 +383,12 @@ func TestSMST_OrphanRemoval(t *testing.T) { err = smst.Update([]byte("testKey"), []byte("testValue2"), 2) require.NoError(t, err) require.Equal(t, 1, nodeCount(t)) - require.Equal(t, uint64(1), impl.Count()) + require.Equal(t, uint64(1), impl.MustCount()) err = smst.Delete([]byte("testKey")) require.NoError(t, err) require.Equal(t, 0, nodeCount(t)) - require.Equal(t, uint64(0), impl.Count()) + require.Equal(t, uint64(0), impl.MustCount()) }) type testCase struct { @@ -436,7 +436,7 @@ func TestSMST_OrphanRemoval(t *testing.T) { require.NoError(t, err, tci) } require.Equal(t, tc.expectedNodeCount, nodeCount(t), tci) - require.Equal(t, uint64(tc.expectedLeafCount), impl.Count()) + require.Equal(t, uint64(tc.expectedLeafCount), impl.MustCount()) // Overwrite doesn't change node or leaf count for _, key := range tc.keys { @@ -444,7 +444,7 @@ func TestSMST_OrphanRemoval(t *testing.T) { require.NoError(t, err, tci) } require.Equal(t, tc.expectedNodeCount, nodeCount(t), tci) - require.Equal(t, uint64(tc.expectedLeafCount), impl.Count()) + require.Equal(t, uint64(tc.expectedLeafCount), impl.MustCount()) // Deletion removes all nodes except root for _, key := range tc.keys { @@ -452,13 +452,13 @@ func TestSMST_OrphanRemoval(t *testing.T) { require.NoError(t, err, tci) } require.Equal(t, 1, nodeCount(t), tci) - require.Equal(t, uint64(1), impl.Count()) + require.Equal(t, uint64(1), impl.MustCount()) // Deleting and re-inserting a persisted node doesn't change count require.NoError(t, smst.Delete([]byte("testKey"))) require.NoError(t, smst.Update([]byte("testKey"), []byte("testValue"), 10)) require.Equal(t, 1, nodeCount(t), tci) - require.Equal(t, uint64(1), impl.Count()) + require.Equal(t, uint64(1), impl.MustCount()) }) } } @@ -486,12 +486,12 @@ func TestSMST_TotalSum(t *testing.T) { rootCount := binary.BigEndian.Uint64(countBz) // Retrieve and compare the sum - sum := smst.Sum() + sum := smst.MustSum() require.Equal(t, sum, uint64(15)) require.Equal(t, sum, rootSum) // Retrieve and compare the count - count := smst.Count() + count := smst.MustCount() require.Equal(t, count, uint64(3)) require.Equal(t, count, rootCount) @@ -506,22 +506,22 @@ func TestSMST_TotalSum(t *testing.T) { // Check that the sum is correct after deleting a key err = smst.Delete([]byte("key1")) require.NoError(t, err) - sum = smst.Sum() + sum = smst.MustSum() require.Equal(t, sum, uint64(10)) // Check that the count is correct after deleting a key - count = smst.Count() + count = smst.MustCount() require.Equal(t, count, uint64(2)) // Check that the sum is correct after importing the trie require.NoError(t, smst.Commit()) root2 := smst.Root() smst = ImportSparseMerkleSumTrie(snm, sha256.New(), root2) - sum = smst.Sum() + sum = smst.MustSum() require.Equal(t, sum, uint64(10)) // Check that the count is correct after importing the trie - count = smst.Count() + count = smst.MustCount() require.Equal(t, count, uint64(2)) // Calculate the total sum of a larger trie @@ -532,11 +532,11 @@ func TestSMST_TotalSum(t *testing.T) { require.NoError(t, err) } require.NoError(t, smst.Commit()) - sum = smst.Sum() + sum = smst.MustSum() require.Equal(t, sum, uint64(49995000)) // Check that the count is correct after building a larger trie - count = smst.Count() + count = smst.MustCount() require.Equal(t, count, uint64(9999)) } @@ -584,7 +584,7 @@ func TestSMST_Retrieval(t *testing.T) { require.Equal(t, uint64(5), sum) root := smst.Root() - sum = smst.Sum() + sum = smst.MustSum() require.Equal(t, sum, uint64(15)) lazy := ImportSparseMerkleSumTrie(snm, sha256.New(), root, WithValueHasher(nil)) @@ -604,9 +604,9 @@ func TestSMST_Retrieval(t *testing.T) { require.Equal(t, []byte("value3"), value) require.Equal(t, uint64(5), sum) - sum = lazy.Sum() + sum = lazy.MustSum() require.Equal(t, sum, uint64(15)) - count := lazy.Count() + count := lazy.MustCount() require.Equal(t, count, uint64(3)) } From 5b8f70d9a7b20d101e5f9ed1346293cd9bddf012 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Tue, 16 Jul 2024 09:24:39 +0200 Subject: [PATCH 04/10] refactor: simplify MerkleSumRoot & add #HasHashLength() --- root.go | 85 +++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 52 insertions(+), 33 deletions(-) diff --git a/root.go b/root.go index faf315c..0e4db6b 100644 --- a/root.go +++ b/root.go @@ -3,20 +3,14 @@ package smt import ( "encoding/binary" "fmt" -) - -const ( - // These are intentionally exposed to allow for for testing and custom - // implementations of downstream applications. - SmtRootSizeBytes = 32 - SmstRootSizeBytes = SmtRootSizeBytes + sumSizeBytes + countSizeBytes + "hash" ) // MustSum returns the uint64 sum of the merkle root, it checks the length of the // merkle root and if it is no the same as the size of the SMST's expected // root hash it will panic. -func (r MerkleSumRoot) MustSum() uint64 { - sum, err := r.Sum() +func (root MerkleSumRoot) MustSum() uint64 { + sum, err := root.Sum() if err != nil { panic(err) } @@ -27,19 +21,19 @@ func (r MerkleSumRoot) MustSum() uint64 { // Sum returns the uint64 sum of the merkle root, it checks the length of the // merkle root and if it is no the same as the size of the SMST's expected // root hash it will return an error. -func (r MerkleSumRoot) Sum() (uint64, error) { - if len(r) != SmstRootSizeBytes { - return 0, fmt.Errorf("MerkleSumRoot#Sum: not a merkle sum trie") +func (root MerkleSumRoot) Sum() (uint64, error) { + if err := root.validateBasic(); err != nil { + return 0, err } - return getSum(r), nil + return root.sum(), nil } // MustCount returns the uint64 count of the merkle root, a cryptographically secure -// count of the number of non-empty leafs in the tree. It panics if the root hash length -// does not match that of the SMST hasher. -func (r MerkleSumRoot) MustCount() uint64 { - count, err := r.Count() +// count of the number of non-empty leafs in the tree. It panics if the root length +// is invalid. +func (root MerkleSumRoot) MustCount() uint64 { + count, err := root.Count() if err != nil { panic(err) } @@ -48,30 +42,55 @@ func (r MerkleSumRoot) MustCount() uint64 { } // Count returns the uint64 count of the merkle root, a cryptographically secure -// count of the number of non-empty leafs in the tree. It returns an error if the root hash length -// does not match that of the SMST hasher. -func (r MerkleSumRoot) Count() (uint64, error) { - if len(r) != SmstRootSizeBytes { - return 0, fmt.Errorf("MerkleSumRoot#Count: not a merkle sum trie") +// count of the number of non-empty leafs in the tree. It returns an error if the +// root length is invalid. +func (root MerkleSumRoot) Count() (uint64, error) { + if err := root.validateBasic(); err != nil { + return 0, err } - return getCount(r), nil + return root.count(), nil +} + +// HasHashLength returns true if the root hash (digest) length is the same as +// that of the size of the given hasher. +func (root MerkleSumRoot) HasHashLength(hasher hash.Hash) bool { + return root.length() == hasher.Size() } -// getSum returns the sum of the node stored in the root. -func getSum(root []byte) uint64 { +// validateBasic returns an error if the root (digest) length is not a power of two. +func (root MerkleSumRoot) validateBasic() error { + if !isPowerOfTwo(root.length()) { + return fmt.Errorf("MerkleSumRoot#validateBasic: invalid root length") + } + + return nil +} + +// length returns the length of the digest portion of the root. +func (root MerkleSumRoot) length() int { + return len(root) - countSizeBytes - sumSizeBytes +} + +// sum returns the sum of the node stored in the root. +func (root MerkleSumRoot) sum() uint64 { firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(root) - var sumBz [sumSizeBytes]byte - copy(sumBz[:], root[firstSumByteIdx:firstCountByteIdx]) - return binary.BigEndian.Uint64(sumBz[:]) + return binary.BigEndian.Uint64(root[firstSumByteIdx:firstCountByteIdx]) } -// getCount returns the count of the node stored in the root. -func getCount(root []byte) uint64 { +// count returns the count of the node stored in the root. +func (root MerkleSumRoot) count() uint64 { _, firstCountByteIdx := getFirstMetaByteIdx(root) - var countBz [countSizeBytes]byte - copy(countBz[:], root[firstCountByteIdx:]) - return binary.BigEndian.Uint64(countBz[:]) + return binary.BigEndian.Uint64(root[firstCountByteIdx:]) +} + +// isPowerOfTwo function returns true if the input n is a power of 2 +func isPowerOfTwo(n int) bool { + // A power of 2 has only one bit set in its binary representation + if n <= 0 { + return false + } + return (n & (n - 1)) == 0 } From 958f5343613d2a2853d499c2783715133cc66dcb Mon Sep 17 00:00:00 2001 From: Bryan White Date: Tue, 16 Jul 2024 09:25:05 +0200 Subject: [PATCH 05/10] test: rewrite MerkleSumRoot tests --- root_test.go | 100 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 59 insertions(+), 41 deletions(-) diff --git a/root_test.go b/root_test.go index b417ff5..ea37fdf 100644 --- a/root_test.go +++ b/root_test.go @@ -13,64 +13,82 @@ import ( "github.com/pokt-network/smt/kvstore/simplemap" ) -func TestMerkleRoot_TrieTypes(t *testing.T) { +func TestMerkleSumRoot_SumAndCountSuccess(t *testing.T) { tests := []struct { - desc string - sumTree bool - hasher hash.Hash - expectedPanic string + desc string + hasher hash.Hash }{ { - desc: "successfully: gets sum of sha256 hasher SMST", - sumTree: true, - hasher: sha256.New(), - expectedPanic: "", + desc: "sha256 hasher", + hasher: sha256.New(), }, { - desc: "successfully: gets sum of sha512 hasher SMST", - sumTree: true, - hasher: sha512.New(), - expectedPanic: "", + desc: "sha512 hasher", + hasher: sha512.New(), }, + } + + nodeStore := simplemap.NewSimpleMap() + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + t.Cleanup(func() { + require.NoError(t, nodeStore.ClearAll()) + }) + trie := smt.NewSparseMerkleSumTrie(nodeStore, test.hasher) + for i := uint64(0); i < 10; i++ { + require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i)) + } + + sum, sumErr := trie.Sum() + require.NoError(t, sumErr) + + count, countErr := trie.Count() + require.NoError(t, countErr) + + require.EqualValues(t, uint64(45), sum) + require.EqualValues(t, uint64(10), count) + }) + } +} + +func TestMekleRoot_SumAndCountError(t *testing.T) { + tests := []struct { + desc string + hasher hash.Hash + }{ { - desc: "failure: panics for sha256 hasher SMT", - sumTree: false, - hasher: sha256.New(), - expectedPanic: "roo#sum: not a merkle sum trie", + desc: "sha256 hasher", + hasher: sha256.New(), }, { - desc: "failure: panics for sha512 hasher SMT", - sumTree: false, - hasher: sha512.New(), - expectedPanic: "roo#sum: not a merkle sum trie", + desc: "sha512 hasher", + hasher: sha512.New(), }, } nodeStore := simplemap.NewSimpleMap() - for _, tt := range tests { - tt := tt - t.Run(tt.desc, func(t *testing.T) { + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { t.Cleanup(func() { require.NoError(t, nodeStore.ClearAll()) }) - if tt.sumTree { - trie := smt.NewSparseMerkleSumTrie(nodeStore, tt.hasher) - for i := uint64(0); i < 10; i++ { - require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i)) - } - require.NotNil(t, trie.MustSum()) - require.EqualValues(t, 45, trie.MustSum()) - require.EqualValues(t, 10, trie.MustCount()) - - return - } - trie := smt.NewSparseMerkleTrie(nodeStore, tt.hasher) - for i := 0; i < 10; i++ { - require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)))) - } - if panicStr := recover(); panicStr != nil { - require.Equal(t, tt.expectedPanic, panicStr) + trie := smt.NewSparseMerkleSumTrie(nodeStore, test.hasher) + for i := uint64(0); i < 10; i++ { + require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i)) } + + root := trie.Root() + + // Mangle the root bytes. + root = root[:len(root)-1] + + sum, sumErr := root.Sum() + require.Error(t, sumErr) + require.Equal(t, uint64(0), sum) + + count, countErr := root.Count() + require.Error(t, countErr) + require.Equal(t, uint64(0), count) }) } } From 934fc199c8b3903dc4bf8fea6c70e1f11f8056ab Mon Sep 17 00:00:00 2001 From: Bryan White Date: Tue, 16 Jul 2024 09:43:28 +0200 Subject: [PATCH 06/10] refactor: #HasHashLength(hash.Hash) to #HasDigestSize(int) --- root.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/root.go b/root.go index 0e4db6b..3d5f841 100644 --- a/root.go +++ b/root.go @@ -3,7 +3,6 @@ package smt import ( "encoding/binary" "fmt" - "hash" ) // MustSum returns the uint64 sum of the merkle root, it checks the length of the @@ -52,10 +51,10 @@ func (root MerkleSumRoot) Count() (uint64, error) { return root.count(), nil } -// HasHashLength returns true if the root hash (digest) length is the same as +// HasDigestSize returns true if the root hash (digest) length is the same as // that of the size of the given hasher. -func (root MerkleSumRoot) HasHashLength(hasher hash.Hash) bool { - return root.length() == hasher.Size() +func (root MerkleSumRoot) HasDigestSize(size int) bool { + return root.length() == size } // validateBasic returns an error if the root (digest) length is not a power of two. From 96df23c426989d69df44547c402fa2bf561867ea Mon Sep 17 00:00:00 2001 From: Bryan White Date: Tue, 16 Jul 2024 09:49:36 +0200 Subject: [PATCH 07/10] refactor: #length() to #DigestSize() --- root.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/root.go b/root.go index 3d5f841..06da067 100644 --- a/root.go +++ b/root.go @@ -51,26 +51,26 @@ func (root MerkleSumRoot) Count() (uint64, error) { return root.count(), nil } +// DigestSize returns the length of the digest portion of the root. +func (root MerkleSumRoot) DigestSize() int { + return len(root) - countSizeBytes - sumSizeBytes +} + // HasDigestSize returns true if the root hash (digest) length is the same as // that of the size of the given hasher. func (root MerkleSumRoot) HasDigestSize(size int) bool { - return root.length() == size + return root.DigestSize() == size } // validateBasic returns an error if the root (digest) length is not a power of two. func (root MerkleSumRoot) validateBasic() error { - if !isPowerOfTwo(root.length()) { + if !isPowerOfTwo(root.DigestSize()) { return fmt.Errorf("MerkleSumRoot#validateBasic: invalid root length") } return nil } -// length returns the length of the digest portion of the root. -func (root MerkleSumRoot) length() int { - return len(root) - countSizeBytes - sumSizeBytes -} - // sum returns the sum of the node stored in the root. func (root MerkleSumRoot) sum() uint64 { firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(root) From e4995408f5ce4bea37714fefdbf9ed123262ffa9 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Tue, 16 Jul 2024 09:55:07 +0200 Subject: [PATCH 08/10] chore: self-review improvements Signed-off-by: Bryan White --- root.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/root.go b/root.go index 06da067..618bbd9 100644 --- a/root.go +++ b/root.go @@ -56,13 +56,13 @@ func (root MerkleSumRoot) DigestSize() int { return len(root) - countSizeBytes - sumSizeBytes } -// HasDigestSize returns true if the root hash (digest) length is the same as +// HasDigestSize returns true if the root digest size is the same as // that of the size of the given hasher. func (root MerkleSumRoot) HasDigestSize(size int) bool { return root.DigestSize() == size } -// validateBasic returns an error if the root (digest) length is not a power of two. +// validateBasic returns an error if the root digest size is not a power of two. func (root MerkleSumRoot) validateBasic() error { if !isPowerOfTwo(root.DigestSize()) { return fmt.Errorf("MerkleSumRoot#validateBasic: invalid root length") From a7a3f71960af7f79a8ad9ba26079305e9e03fdce Mon Sep 17 00:00:00 2001 From: Bryan White Date: Tue, 16 Jul 2024 10:07:29 +0200 Subject: [PATCH 09/10] refactor: move applicable root methods to MerkleRoot and inherit in MerkleSumRoot --- root.go | 70 ++++++++++++++++++++++++++++---------------------------- types.go | 2 +- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/root.go b/root.go index 618bbd9..29eb43d 100644 --- a/root.go +++ b/root.go @@ -5,33 +5,10 @@ import ( "fmt" ) -// MustSum returns the uint64 sum of the merkle root, it checks the length of the -// merkle root and if it is no the same as the size of the SMST's expected -// root hash it will panic. -func (root MerkleSumRoot) MustSum() uint64 { - sum, err := root.Sum() - if err != nil { - panic(err) - } - - return sum -} - -// Sum returns the uint64 sum of the merkle root, it checks the length of the -// merkle root and if it is no the same as the size of the SMST's expected -// root hash it will return an error. -func (root MerkleSumRoot) Sum() (uint64, error) { - if err := root.validateBasic(); err != nil { - return 0, err - } - - return root.sum(), nil -} - // MustCount returns the uint64 count of the merkle root, a cryptographically secure // count of the number of non-empty leafs in the tree. It panics if the root length // is invalid. -func (root MerkleSumRoot) MustCount() uint64 { +func (root MerkleRoot) MustCount() uint64 { count, err := root.Count() if err != nil { panic(err) @@ -43,7 +20,7 @@ func (root MerkleSumRoot) MustCount() uint64 { // Count returns the uint64 count of the merkle root, a cryptographically secure // count of the number of non-empty leafs in the tree. It returns an error if the // root length is invalid. -func (root MerkleSumRoot) Count() (uint64, error) { +func (root MerkleRoot) Count() (uint64, error) { if err := root.validateBasic(); err != nil { return 0, err } @@ -52,18 +29,41 @@ func (root MerkleSumRoot) Count() (uint64, error) { } // DigestSize returns the length of the digest portion of the root. -func (root MerkleSumRoot) DigestSize() int { +func (root MerkleRoot) DigestSize() int { return len(root) - countSizeBytes - sumSizeBytes } // HasDigestSize returns true if the root digest size is the same as // that of the size of the given hasher. -func (root MerkleSumRoot) HasDigestSize(size int) bool { +func (root MerkleRoot) HasDigestSize(size int) bool { return root.DigestSize() == size } +// MustSum returns the uint64 sum of the merkle root, it checks the length of the +// merkle root and if it is no the same as the size of the SMST's expected +// root hash it will panic. +func (root MerkleSumRoot) MustSum() uint64 { + sum, err := root.Sum() + if err != nil { + panic(err) + } + + return sum +} + +// Sum returns the uint64 sum of the merkle root, it checks the length of the +// merkle root and if it is no the same as the size of the SMST's expected +// root hash it will return an error. +func (root MerkleSumRoot) Sum() (uint64, error) { + if err := root.validateBasic(); err != nil { + return 0, err + } + + return root.sum(), nil +} + // validateBasic returns an error if the root digest size is not a power of two. -func (root MerkleSumRoot) validateBasic() error { +func (root MerkleRoot) validateBasic() error { if !isPowerOfTwo(root.DigestSize()) { return fmt.Errorf("MerkleSumRoot#validateBasic: invalid root length") } @@ -71,6 +71,13 @@ func (root MerkleSumRoot) validateBasic() error { return nil } +// count returns the count of the node stored in the root. +func (root MerkleRoot) count() uint64 { + _, firstCountByteIdx := getFirstMetaByteIdx(root) + + return binary.BigEndian.Uint64(root[firstCountByteIdx:]) +} + // sum returns the sum of the node stored in the root. func (root MerkleSumRoot) sum() uint64 { firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(root) @@ -78,13 +85,6 @@ func (root MerkleSumRoot) sum() uint64 { return binary.BigEndian.Uint64(root[firstSumByteIdx:firstCountByteIdx]) } -// count returns the count of the node stored in the root. -func (root MerkleSumRoot) count() uint64 { - _, firstCountByteIdx := getFirstMetaByteIdx(root) - - return binary.BigEndian.Uint64(root[firstCountByteIdx:]) -} - // isPowerOfTwo function returns true if the input n is a power of 2 func isPowerOfTwo(n int) bool { // A power of 2 has only one bit set in its binary representation diff --git a/types.go b/types.go index fa2ca4c..641c96a 100644 --- a/types.go +++ b/types.go @@ -26,7 +26,7 @@ var ( type MerkleRoot []byte // MerkleSumRoot is a type alias for a byte slice returned from SparseMerkleSumTrie#Root(). -type MerkleSumRoot []byte +type MerkleSumRoot = MerkleRoot // A high-level interface that captures the behaviour of all types of nodes type trieNode interface { From b1fa1116f5bc46d2aeae7c5516bcab865a1af87c Mon Sep 17 00:00:00 2001 From: Bryan White Date: Tue, 16 Jul 2024 13:51:44 +0200 Subject: [PATCH 10/10] Revert "refactor: move applicable root methods to MerkleRoot and inherit in MerkleSumRoot" This reverts commit a7a3f71960af7f79a8ad9ba26079305e9e03fdce. --- root.go | 70 ++++++++++++++++++++++++++++---------------------------- types.go | 2 +- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/root.go b/root.go index 29eb43d..618bbd9 100644 --- a/root.go +++ b/root.go @@ -5,10 +5,33 @@ import ( "fmt" ) +// MustSum returns the uint64 sum of the merkle root, it checks the length of the +// merkle root and if it is no the same as the size of the SMST's expected +// root hash it will panic. +func (root MerkleSumRoot) MustSum() uint64 { + sum, err := root.Sum() + if err != nil { + panic(err) + } + + return sum +} + +// Sum returns the uint64 sum of the merkle root, it checks the length of the +// merkle root and if it is no the same as the size of the SMST's expected +// root hash it will return an error. +func (root MerkleSumRoot) Sum() (uint64, error) { + if err := root.validateBasic(); err != nil { + return 0, err + } + + return root.sum(), nil +} + // MustCount returns the uint64 count of the merkle root, a cryptographically secure // count of the number of non-empty leafs in the tree. It panics if the root length // is invalid. -func (root MerkleRoot) MustCount() uint64 { +func (root MerkleSumRoot) MustCount() uint64 { count, err := root.Count() if err != nil { panic(err) @@ -20,7 +43,7 @@ func (root MerkleRoot) MustCount() uint64 { // Count returns the uint64 count of the merkle root, a cryptographically secure // count of the number of non-empty leafs in the tree. It returns an error if the // root length is invalid. -func (root MerkleRoot) Count() (uint64, error) { +func (root MerkleSumRoot) Count() (uint64, error) { if err := root.validateBasic(); err != nil { return 0, err } @@ -29,41 +52,18 @@ func (root MerkleRoot) Count() (uint64, error) { } // DigestSize returns the length of the digest portion of the root. -func (root MerkleRoot) DigestSize() int { +func (root MerkleSumRoot) DigestSize() int { return len(root) - countSizeBytes - sumSizeBytes } // HasDigestSize returns true if the root digest size is the same as // that of the size of the given hasher. -func (root MerkleRoot) HasDigestSize(size int) bool { +func (root MerkleSumRoot) HasDigestSize(size int) bool { return root.DigestSize() == size } -// MustSum returns the uint64 sum of the merkle root, it checks the length of the -// merkle root and if it is no the same as the size of the SMST's expected -// root hash it will panic. -func (root MerkleSumRoot) MustSum() uint64 { - sum, err := root.Sum() - if err != nil { - panic(err) - } - - return sum -} - -// Sum returns the uint64 sum of the merkle root, it checks the length of the -// merkle root and if it is no the same as the size of the SMST's expected -// root hash it will return an error. -func (root MerkleSumRoot) Sum() (uint64, error) { - if err := root.validateBasic(); err != nil { - return 0, err - } - - return root.sum(), nil -} - // validateBasic returns an error if the root digest size is not a power of two. -func (root MerkleRoot) validateBasic() error { +func (root MerkleSumRoot) validateBasic() error { if !isPowerOfTwo(root.DigestSize()) { return fmt.Errorf("MerkleSumRoot#validateBasic: invalid root length") } @@ -71,13 +71,6 @@ func (root MerkleRoot) validateBasic() error { return nil } -// count returns the count of the node stored in the root. -func (root MerkleRoot) count() uint64 { - _, firstCountByteIdx := getFirstMetaByteIdx(root) - - return binary.BigEndian.Uint64(root[firstCountByteIdx:]) -} - // sum returns the sum of the node stored in the root. func (root MerkleSumRoot) sum() uint64 { firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(root) @@ -85,6 +78,13 @@ func (root MerkleSumRoot) sum() uint64 { return binary.BigEndian.Uint64(root[firstSumByteIdx:firstCountByteIdx]) } +// count returns the count of the node stored in the root. +func (root MerkleSumRoot) count() uint64 { + _, firstCountByteIdx := getFirstMetaByteIdx(root) + + return binary.BigEndian.Uint64(root[firstCountByteIdx:]) +} + // isPowerOfTwo function returns true if the input n is a power of 2 func isPowerOfTwo(n int) bool { // A power of 2 has only one bit set in its binary representation diff --git a/types.go b/types.go index 641c96a..fa2ca4c 100644 --- a/types.go +++ b/types.go @@ -26,7 +26,7 @@ var ( type MerkleRoot []byte // MerkleSumRoot is a type alias for a byte slice returned from SparseMerkleSumTrie#Root(). -type MerkleSumRoot = MerkleRoot +type MerkleSumRoot []byte // A high-level interface that captures the behaviour of all types of nodes type trieNode interface {