Skip to content

Commit

Permalink
Moved trie spec into its own file
Browse files Browse the repository at this point in the history
  • Loading branch information
Olshansk committed Feb 12, 2024
1 parent ee67740 commit c263cd7
Show file tree
Hide file tree
Showing 9 changed files with 361 additions and 309 deletions.
17 changes: 12 additions & 5 deletions hasher.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package smt

import (
"encoding/binary"
"hash"
)

Expand Down Expand Up @@ -93,8 +94,8 @@ func (th *trieHasher) digestData(data []byte) []byte {
return digest
}

// digestLeaf returns the encoded leaf data as well as its hash (i.e. digest)
func (th *trieHasher) digestLeaf(path, data []byte) (digest, value []byte) {
// digestLeafNode returns the encoded leaf data as well as its hash (i.e. digest)
func (th *trieHasher) digestLeafNode(path, data []byte) (digest, value []byte) {
value = encodeLeafNode(path, data)
digest = th.digestData(value)
return
Expand All @@ -106,8 +107,8 @@ func (th *trieHasher) digestInnerNode(leftData, rightData []byte) (digest, value
return
}

func (th *trieHasher) digestSumLeaf(path, leafData []byte) (digest, value []byte) {
value = encodeLeafNode(path, leafData)
func (th *trieHasher) digestSumLeafNode(path, data []byte) (digest, value []byte) {
value = encodeLeafNode(path, data)
digest = th.digestData(value)
digest = append(digest, value[len(value)-sumSizeBits:]...)
return
Expand All @@ -126,7 +127,13 @@ func (th *trieHasher) parseInnerNode(data []byte) (leftData, rightData []byte) {
return
}

func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte) {
func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte, sum uint64) {
// Extract the sum from the encoded node data
var sumBz [sumSizeBits]byte
copy(sumBz[:], data[len(data)-sumSizeBits:])
binary.BigEndian.PutUint64(sumBz[:], sum)

// Extract the left and right children
dataWithoutSum := data[:len(data)-sumSizeBits]
leftData = dataWithoutSum[len(innerNodePrefix) : len(innerNodePrefix)+th.hashSize()+sumSizeBits]
rightData = dataWithoutSum[len(innerNodePrefix)+th.hashSize()+sumSizeBits:]
Expand Down
39 changes: 0 additions & 39 deletions node_encoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,44 +50,6 @@ func isInnerNode(data []byte) bool {
return bytes.Equal(data[:prefixLen], innerNodePrefix)
}

// parseLeafNode parses a leafNode into its components
func parseLeafNode(data []byte, ph PathHasher) (path, value []byte) {
// panics if not a leaf node
checkPrefix(data, leafNodePrefix)

path = data[prefixLen : prefixLen+ph.PathSize()]
value = data[prefixLen+ph.PathSize():]
return
}

// parseExtNode parses an extNode into its components
func parseExtNode(data []byte, ph PathHasher) (pathBounds, path, childData []byte) {
// panics if not an extension node
checkPrefix(data, extNodePrefix)

// +2 represents the length of the pathBounds
pathBounds = data[prefixLen : prefixLen+2]
path = data[prefixLen+2 : prefixLen+2+ph.PathSize()]
childData = data[prefixLen+2+ph.PathSize():]
return
}

// parseSumExtNode parses the pathBounds, path, child data and sum from the encoded extension node data
func parseSumExtNode(data []byte, ph PathHasher) (pathBounds, path, childData []byte, sum [sumSizeBits]byte) {
// panics if not an extension node
checkPrefix(data, extNodePrefix)

// Extract the sum from the encoded node data
var sumBz [sumSizeBits]byte
copy(sumBz[:], data[len(data)-sumSizeBits:])

// +2 represents the length of the pathBounds
pathBounds = data[prefixLen : prefixLen+2]
path = data[prefixLen+2 : prefixLen+2+ph.PathSize()]
childData = data[prefixLen+2+ph.PathSize() : len(data)-sumSizeBits]
return
}

// encodeLeafNode encodes leaf nodes. This function applies to both the SMT and
// SMST since the weight of the node is appended to the end of the valueHash.
func encodeLeafNode(path, leafData []byte) (data []byte) {
Expand Down Expand Up @@ -130,7 +92,6 @@ func encodeSumInnerNode(leftData, rightData []byte) (data []byte) {

// encodeSumExtensionNode encodes the data of a sum extension nodes
func encodeSumExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byte) {

// Compute the sum of the current node
var sum [sumSizeBits]byte
copy(sum[:], childData[len(childData)-sumSizeBits:])
Expand Down
46 changes: 27 additions & 19 deletions proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,33 +49,34 @@ func (proof *SparseMerkleProof) Unmarshal(bz []byte) error {
return dec.Decode(proof)
}

// validateBasic performs basic sanity check on the proof so that a malicious
// proof cannot cause the verifier to fatally exit (e.g. due to an index
// out-of-range error) or cause a CPU DoS attack.
func (proof *SparseMerkleProof) validateBasic(spec *TrieSpec) error {
// Do a basic sanity check on the proof, so that a malicious proof cannot
// cause the verifier to fatally exit (e.g. due to an index out-of-range
// error) or cause a CPU DoS attack.

// Check that the number of supplied sidenodes does not exceed the maximum possible.
// Verify the number of supplied sideNodes does not exceed the possible maximum.
if len(proof.SideNodes) > spec.ph.PathSize()*8 {
return fmt.Errorf("too many side nodes: got %d but max is %d", len(proof.SideNodes), spec.ph.PathSize()*8)
}
// Check that leaf data for non-membership proofs is a valid size.
lps := len(leafNodePrefix) + spec.ph.PathSize()
if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < lps {
return fmt.Errorf("invalid non-membership leaf data size: got %d but min is %d", len(proof.NonMembershipLeafData), lps)
}

// Check that all supplied sidenodes are the correct size.
for _, v := range proof.SideNodes {
if len(v) != spec.hashSize() {
return fmt.Errorf("invalid side node size: got %d but want %d", len(v), spec.hashSize())
}
// Verify that the non-membership leaf data is of the correct size.
leafPathSize := len(leafNodePrefix) + spec.ph.PathSize()
if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < leafPathSize {
return fmt.Errorf("invalid non-membership leaf data size: got %d but min is %d", len(proof.NonMembershipLeafData), leafPathSize)
}

// Check that the sibling data hashes to the first side node if not nil
if proof.SiblingData == nil || len(proof.SideNodes) == 0 {
return nil
}
siblingHash := hashPreimage(spec, proof.SiblingData)

// Check that all supplied sideNodes are the correct size.
for _, sideNodeValue := range proof.SideNodes {
if len(sideNodeValue) != spec.hashSize() {
return fmt.Errorf("invalid side node size: got %d but want %d", len(sideNodeValue), spec.hashSize())
}
}

siblingHash := spec.hashPreimage(proof.SiblingData)
if eq := bytes.Equal(proof.SideNodes[0], siblingHash); !eq {
return fmt.Errorf("invalid sibling data hash: got %x but want %x", siblingHash, proof.SideNodes[0])
}
Expand Down Expand Up @@ -320,7 +321,13 @@ func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *Trie
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, spec)
}

func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, value []byte, spec *TrieSpec) (bool, [][][]byte, error) {
// verifyProofWithUpdates
func verifyProofWithUpdates(
proof *SparseMerkleProof,
root, key, value []byte,
spec *TrieSpec,
) (bool, [][][]byte, error) {
// Retrieve the trie path for the give key
path := spec.ph.Path(key)

if err := proof.validateBasic(spec); err != nil {
Expand All @@ -336,7 +343,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v
currentHash = spec.placeholder()
} else { // Leaf is an unrelated leaf.
var actualPath, valueHash []byte
actualPath, valueHash = parseLeafNode(proof.NonMembershipLeafData, spec.ph)
actualPath, valueHash = spec.parseLeafNode(proof.NonMembershipLeafData)
if bytes.Equal(actualPath, path) {
// This is not an unrelated leaf; non-membership proof failed.
return false, nil, errors.Join(ErrBadProof, errors.New("non-membership proof on related leaf"))
Expand All @@ -347,7 +354,8 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v
update[0], update[1] = currentHash, currentData
updates = append(updates, update)
}
} else { // Membership proof.
} else {
// Membership proof.
valueHash := spec.valueHash(value)
currentHash, currentData = spec.digestLeaf(path, valueHash)
update := make([][]byte, 2)
Expand Down
82 changes: 59 additions & 23 deletions smst_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,31 +82,67 @@ func exportToCSV(
nodeStore kvstore.MapStore,
) {
t.Helper()
rootHash := smst.Root()
rootNode, err := nodeStore.Get(rootHash)
/*
rootHash := smst.Root()
rootNode, err := nodeStore.Get(rootHash)
require.NoError(t, err)
leftData, rightData := smst.Spec().th.parseSumInnerNode(rootNode)
leftChild, err := nodeStore.Get(leftData)
require.NoError(t, err)
rightChild, err := nodeStore.Get(rightData)
require.NoError(t, err)
fmt.Println("Prefix", "isExt", "isLeaf", "isInner")
// false false true
fmt.Println("root", isExtNode(rootNode), isLeafNode(rootNode), isInnerNode(rootNode), rootNode)
fmt.Println()
// false false false
fmt.Println("left", isExtNode(leftChild), isLeafNode(leftChild), isInnerNode(leftChild), leftChild)
fmt.Println()
// false false false
fmt.Println("right", isExtNode(rightChild), isLeafNode(rightChild), isInnerNode(rightChild), rightChild)
fmt.Println()
*/

/*
for key, value := range innerMap {
v, s, err := smst.Get([]byte(key))
require.NoError(t, err)
fmt.Println(v, s, []byte(key))
fmt.Println(value)
fmt.Println("")
fmt.Println("")
}
*/

helper(t, smst, nodeStore, smst.Root())
}

func helper(t *testing.T, smst SparseMerkleSumTrie, nodeStore kvstore.MapStore, nodeDigest []byte) {
t.Helper()

node, err := nodeStore.Get(nodeDigest)
require.NoError(t, err)

leftChild, rightChild := smst.Spec().th.parseSumInnerNode(rootNode)
fmt.Println("Prefix", "isExt", "isLeaf", "isInner")
fmt.Println(rootNode[:1], isExtNode(rootNode), isLeafNode(rootNode), isInnerNode(rootNode))
fmt.Println(leftChild[:1], isExtNode(leftChild), isLeafNode(leftChild), isInnerNode(leftChild))
fmt.Println(rightChild[:1], isExtNode(rightChild), isLeafNode(rightChild), isInnerNode(rightChild))
// path, value := parseLeafNode(rightChild, smst.Spec().ph)
// path2, value2 := parseLeafNode(leftChild, smst.Spec().ph)
// fmt.Println(path, "~~~", value, "~~~", path2, "~~~", value2)

for key, value := range innerMap {
v, s, err := smst.Get([]byte(key))
require.NoError(t, err)
fmt.Println(v, s)
fmt.Println(value)
fmt.Println("")
fmt.Println("")
fmt.Println()
if isExtNode(node) {
pathBounds, path, childData, sum := smst.Spec().parseSumExtNode(node)
fmt.Println("ext node sum", sum)
fmt.Println(pathBounds, path)
helper(t, smst, nodeStore, childData)
return
} else if isLeafNode(node) {
path, value := smst.Spec().parseLeafNode(node)
fmt.Println("leaf node sum", 0)
fmt.Println(path, value)
} else if isInnerNode(node) {
leftData, rightData, sum := smst.Spec().th.parseSumInnerNode(node)
fmt.Println("inner node sum", sum)
helper(t, smst, nodeStore, leftData)
helper(t, smst, nodeStore, rightData)
}

// Export the trie to a CSV file
// err := smt.ExportToCSV("export.csv")
// if err != nil {
// panic(err)
// }
// v, s, err := smst.Get([]byte(key))
// require.NoError(t, err)
// require.Equal(t, []byte(value), v)
// require.Equal(t, sum, s)
}
2 changes: 1 addition & 1 deletion smst_proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func TestSMST_Proof_Operations(t *testing.T) {
binary.BigEndian.PutUint64(sum[:], 5)
tval := base.valueHash([]byte("testValue"))
tval = append(tval, sum[:]...)
_, leafData := base.th.digestSumLeaf(base.ph.Path([]byte("testKey2")), tval)
_, leafData := base.th.digestSumLeafNode(base.ph.Path([]byte("testKey2")), tval)
proof = &SparseMerkleProof{
SideNodes: proof.SideNodes,
NonMembershipLeafData: leafData,
Expand Down
12 changes: 6 additions & 6 deletions smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ func (smt *SMT) ProveClosest(path []byte) (
return proof, nil
}

// resolveLazy resolves resolves a stub into a cached node
// resolveLazy resolves a lazy note into a cached node depending on the tree type
func (smt *SMT) resolveLazy(node trieNode) (trieNode, error) {
stub, ok := node.(*lazyNode)
if !ok {
Expand Down Expand Up @@ -550,15 +550,15 @@ func (smt *SMT) resolveNode(digest []byte) (trieNode, error) {

// Return the appropriate node type based on the first byte of the data
if isLeafNode(data) {
path, valueHash := parseLeafNode(data, smt.ph)
path, valueHash := smt.parseLeafNode(data)
return &leafNode{
path: path,
valueHash: valueHash,
persisted: true,
digest: digest,
}, nil
} else if isExtNode(data) {
pathBounds, path, childData := parseExtNode(data, smt.ph)
pathBounds, path, childData := smt.parseExtNode(data)
return &extensionNode{
path: path,
pathBounds: [2]byte(pathBounds),
Expand Down Expand Up @@ -595,15 +595,15 @@ func (smt *SMT) resolveSumNode(digest []byte) (trieNode, error) {

// Return the appropriate node type based on the first byte of the data
if isLeafNode(data) {
path, valueHash := parseLeafNode(data, smt.ph)
path, valueHash := smt.parseLeafNode(data)
return &leafNode{
path: path,
valueHash: valueHash,
persisted: true,
digest: digest,
}, nil
} else if isExtNode(data) {
pathBounds, path, childData, _ := parseSumExtNode(data, smt.ph)
pathBounds, path, childData, _ := smt.parseSumExtNode(data)
return &extensionNode{
path: path,
pathBounds: [2]byte(pathBounds),
Expand All @@ -612,7 +612,7 @@ func (smt *SMT) resolveSumNode(digest []byte) (trieNode, error) {
digest: digest,
}, nil
} else if isInnerNode(data) {
leftData, rightData := smt.th.parseSumInnerNode(data)
leftData, rightData, _ := smt.th.parseSumInnerNode(data)
return &innerNode{
leftChild: &lazyNode{leftData},
rightChild: &lazyNode{rightData},
Expand Down
2 changes: 1 addition & 1 deletion smt_proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func TestSMT_Proof_Operations(t *testing.T) {
require.False(t, result)

// Try proving a default value for a non-default leaf.
_, leafData := base.th.digestLeaf(base.ph.Path([]byte("testKey2")), base.valueHash([]byte("testValue")))
_, leafData := base.th.digestLeafNode(base.ph.Path([]byte("testKey2")), base.valueHash([]byte("testValue")))
proof = &SparseMerkleProof{
SideNodes: proof.SideNodes,
NonMembershipLeafData: leafData,
Expand Down
Loading

0 comments on commit c263cd7

Please sign in to comment.