Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Audit] Address Audit Issues and Suggestions #42

Merged
merged 6 commits into from
Apr 9, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: consolidate ClosestProof verification and remove the NilPathHas…
…her method
  • Loading branch information
h5law committed Mar 19, 2024
commit 4831b52ba5f12955b8ac94e0ae96ea189116d249
18 changes: 0 additions & 18 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package smt

import (
"hash"
)

// Option is a function that configures SparseMerkleTrie.
type Option func(*TrieSpec)

@@ -16,17 +12,3 @@ func WithPathHasher(ph PathHasher) Option {
func WithValueHasher(vh ValueHasher) Option {
return func(ts *TrieSpec) { ts.vh = vh }
}

// NoPrehashSpec returns a new TrieSpec that has a nil Value Hasher and a nil
// Path Hasher
// NOTE: This should only be used when values are already hashed and a path is
// used instead of a key during proof verification, otherwise these will be
// double hashed and produce an incorrect leaf digest invalidating the proof.
func NoPrehashSpec(hasher hash.Hash, sumTrie bool) *TrieSpec {
spec := newTrieSpec(hasher, sumTrie)
opt := WithPathHasher(newNilPathHasher(hasher.Size()))
opt(&spec)
opt = WithValueHasher(nil)
opt(&spec)
return &spec
}
66 changes: 54 additions & 12 deletions proofs.go
Original file line number Diff line number Diff line change
@@ -61,7 +61,11 @@ func (proof *SparseMerkleProof) validateBasic(spec *TrieSpec) error {
// Check that leaf data for non-membership proofs is a valid size.
lps := len(leafPrefix) + spec.ph.PathSize()
if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < lps {
return fmt.Errorf("invalid non-membership leaf data size: got %d but min is %d", len(proof.NonMembershipLeafData), lps)
return fmt.Errorf(
"invalid non-membership leaf data size: got %d but min is %d",
len(proof.NonMembershipLeafData),
lps,
)
}

// Check that all supplied sidenodes are the correct size.
@@ -133,7 +137,11 @@ func (proof *SparseCompactMerkleProof) validateBasic(spec *TrieSpec) error {

// Compact proofs: check that NumSideNodes is within the right range.
if proof.NumSideNodes < 0 || proof.NumSideNodes > spec.ph.PathSize()*8 {
return fmt.Errorf("invalid number of side nodes: got %d, min is 0 and max is %d", len(proof.SideNodes), spec.ph.PathSize()*8)
return fmt.Errorf(
"invalid number of side nodes: got %d, min is 0 and max is %d",
len(proof.SideNodes),
spec.ph.PathSize()*8,
)
}

// Compact proofs: check that the length of the bit mask is as expected
@@ -185,6 +193,17 @@ func (proof *SparseMerkleClosestProof) Unmarshal(bz []byte) error {
return dec.Decode(proof)
}

// GetValueHash returns the value hash of the closest proof.
func (proof *SparseMerkleClosestProof) GetValueHash(spec *TrieSpec) []byte {
if proof.ClosestValueHash == nil {
return nil
}
if spec.sumTrie {
return proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize]
}
return proof.ClosestValueHash
}

func (proof *SparseMerkleClosestProof) validateBasic(spec *TrieSpec) error {
// ensure the depth of the leaf node being proven is within the path size
if proof.Depth < 0 || proof.Depth > spec.ph.PathSize()*8 {
@@ -246,7 +265,12 @@ func (proof *SparseCompactMerkleClosestProof) validateBasic(spec *TrieSpec) erro
}
for i, b := range proof.FlippedBits {
if len(b) > maxSliceLen {
return fmt.Errorf("invalid compressed flipped bit index %d: got length %d, max is %d]", i, bytesToInt(b), maxSliceLen)
return fmt.Errorf(
"invalid compressed flipped bit index %d: got length %d, max is %d]",
i,
bytesToInt(b),
maxSliceLen,
)
}
}
// perform a sanity check on the closest proof
@@ -301,26 +325,38 @@ func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint6

// VerifyClosestProof verifies a Merkle proof for a proof of inclusion for a leaf
// found to have the closest path to the one provided to the proof structure
//
// TO_AUDITOR: This is akin to an inclusion proof with N (num flipped bits) exclusion
// proof wrapped into one and needs to be reviewed from an algorithm POV.
func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *TrieSpec) (bool, error) {
if err := proof.validateBasic(spec); err != nil {
return false, errors.Join(ErrBadProof, err)
}
if !spec.sumTrie {
return VerifyProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash, spec)
// Create a new TrieSpec with a nil path hasher - as the ClosestProof
// already contains a hashed path, double hashing it will invalidate the
// proof.
nilSpec := &TrieSpec{
th: spec.th,
ph: newNilPathHasher(spec.ph.PathSize()),
vh: spec.vh,
sumTrie: spec.sumTrie,
}
if !nilSpec.sumTrie {
return VerifyProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash, nilSpec)
}
if proof.ClosestValueHash == nil {
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, spec)
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, nilSpec)
}
sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSize:]
sum := binary.BigEndian.Uint64(sumBz)
valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize]
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, spec)
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, nilSpec)
}

func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, value []byte, spec *TrieSpec) (bool, [][][]byte, error) {
func verifyProofWithUpdates(
proof *SparseMerkleProof,
root []byte,
key []byte,
value []byte,
spec *TrieSpec,
) (bool, [][][]byte, error) {
path := spec.ph.Path(key)

if err := proof.validateBasic(spec); err != nil {
@@ -384,7 +420,13 @@ func VerifyCompactProof(proof *SparseCompactMerkleProof, root []byte, key, value
}

// VerifyCompactSumProof is similar to VerifySumProof but for a compacted Merkle proof.
func VerifyCompactSumProof(proof *SparseCompactMerkleProof, root []byte, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) {
func VerifyCompactSumProof(
proof *SparseCompactMerkleProof,
root []byte,
key, value []byte,
sum uint64,
spec *TrieSpec,
) (bool, error) {
decompactedProof, err := DecompactProof(proof, spec)
if err != nil {
return false, errors.Join(ErrBadProof, err)
46 changes: 33 additions & 13 deletions smst_proofs_test.go
Original file line number Diff line number Diff line change
@@ -79,7 +79,14 @@ func TestSMST_Proof_Operations(t *testing.T) {
result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 10, base) // wrong value and sum
require.NoError(t, err)
require.False(t, result)
result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey"), []byte("testValue"), 5, base) // invalid proof
result, err = VerifySumProof(
randomiseSumProof(proof),
root,
[]byte("testKey"),
[]byte("testValue"),
5,
base,
) // invalid proof
require.NoError(t, err)
require.False(t, result)

@@ -98,7 +105,14 @@ func TestSMST_Proof_Operations(t *testing.T) {
result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("badValue"), 10, base) // wrong value and sum
require.NoError(t, err)
require.False(t, result)
result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey2"), []byte("testValue"), 5, base) // invalid proof
result, err = VerifySumProof(
randomiseSumProof(proof),
root,
[]byte("testKey2"),
[]byte("testValue"),
5,
base,
) // invalid proof
require.NoError(t, err)
require.False(t, result)

@@ -129,7 +143,14 @@ func TestSMST_Proof_Operations(t *testing.T) {
result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultValue, 5, base) // wrong sum
require.NoError(t, err)
require.False(t, result)
result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey3"), defaultValue, 0, base) // invalid proof
result, err = VerifySumProof(
randomiseSumProof(proof),
root,
[]byte("testKey3"),
defaultValue,
0,
base,
) // invalid proof
require.NoError(t, err)
require.False(t, result)
}
@@ -204,7 +225,6 @@ func TestSMST_Proof_ValidateBasic(t *testing.T) {
func TestSMST_ClosestProof_ValidateBasic(t *testing.T) {
smn := simplemap.NewSimpleMap()
smst := NewSparseMerkleSumTrie(smn, sha256.New())
np := NoPrehashSpec(sha256.New(), true)
base := smst.Spec()
path := sha256.Sum256([]byte("testKey2"))
flipPathBit(path[:], 3)
@@ -227,14 +247,14 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) {
require.NoError(t, err)
proof.Depth = -1
require.EqualError(t, proof.validateBasic(base), "invalid depth: got -1, outside of [0, 256]")
result, err := VerifyClosestProof(proof, root, np)
result, err := VerifyClosestProof(proof, root, smst.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
require.Error(t, err)
proof.Depth = 257
require.EqualError(t, proof.validateBasic(base), "invalid depth: got 257, outside of [0, 256]")
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smst.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
@@ -244,14 +264,14 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) {
require.NoError(t, err)
proof.FlippedBits[0] = -1
require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got -1, outside of [0, 8]")
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smst.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
require.Error(t, err)
proof.FlippedBits[0] = 9
require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got 9, outside of [0, 8]")
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smst.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
@@ -265,7 +285,7 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) {
proof.validateBasic(base),
"invalid closest path: 8d13809f932d0296b88c1913231ab4b403f05c88363575476204fef6930f22ae (not equal at bit: 3)",
)
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smst.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
@@ -326,7 +346,7 @@ func TestSMST_ProveClosest(t *testing.T) {
ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields
})

result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true))
result, err = VerifyClosestProof(proof, root, smst.Spec())
require.NoError(t, err)
require.True(t, result)

@@ -352,7 +372,7 @@ func TestSMST_ProveClosest(t *testing.T) {
ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields
})

result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true))
result, err = VerifyClosestProof(proof, root, smst.Spec())
require.NoError(t, err)
require.True(t, result)
}
@@ -381,7 +401,7 @@ func TestSMST_ProveClosest_Empty(t *testing.T) {
ClosestProof: &SparseMerkleProof{},
})

result, err := VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true))
result, err := VerifyClosestProof(proof, smst.Root(), smst.Spec())
require.NoError(t, err)
require.True(t, result)
}
@@ -419,7 +439,7 @@ func TestSMST_ProveClosest_OneNode(t *testing.T) {
ClosestProof: &SparseMerkleProof{},
})

result, err := VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true))
result, err := VerifyClosestProof(proof, smst.Root(), smst.Spec())
require.NoError(t, err)
require.True(t, result)
}
19 changes: 9 additions & 10 deletions smt_proofs_test.go
Original file line number Diff line number Diff line change
@@ -178,7 +178,6 @@ func TestSMT_Proof_ValidateBasic(t *testing.T) {
func TestSMT_ClosestProof_ValidateBasic(t *testing.T) {
smn := simplemap.NewSimpleMap()
smt := NewSparseMerkleTrie(smn, sha256.New())
np := NoPrehashSpec(sha256.New(), false)
base := smt.Spec()
path := sha256.Sum256([]byte("testKey2"))
flipPathBit(path[:], 3)
@@ -201,14 +200,14 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) {
require.NoError(t, err)
proof.Depth = -1
require.EqualError(t, proof.validateBasic(base), "invalid depth: got -1, outside of [0, 256]")
result, err := VerifyClosestProof(proof, root, np)
result, err := VerifyClosestProof(proof, root, smt.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
require.Error(t, err)
proof.Depth = 257
require.EqualError(t, proof.validateBasic(base), "invalid depth: got 257, outside of [0, 256]")
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smt.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
@@ -218,14 +217,14 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) {
require.NoError(t, err)
proof.FlippedBits[0] = -1
require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got -1, outside of [0, 8]")
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smt.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
require.Error(t, err)
proof.FlippedBits[0] = 9
require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got 9, outside of [0, 8]")
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smt.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
@@ -239,7 +238,7 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) {
proof.validateBasic(base),
"invalid closest path: 8d13809f932d0296b88c1913231ab4b403f05c88363575476204fef6930f22ae (not equal at bit: 3)",
)
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smt.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
@@ -287,7 +286,7 @@ func TestSMT_ProveClosest(t *testing.T) {
checkClosestCompactEquivalence(t, proof, smt.Spec())
require.NotEqual(t, proof, &SparseMerkleClosestProof{})

result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false))
result, err = VerifyClosestProof(proof, root, smt.Spec())
require.NoError(t, err)
require.True(t, result)
closestPath := sha256.Sum256([]byte("testKey2"))
@@ -304,7 +303,7 @@ func TestSMT_ProveClosest(t *testing.T) {
checkClosestCompactEquivalence(t, proof, smt.Spec())
require.NotEqual(t, proof, &SparseMerkleClosestProof{})

result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false))
result, err = VerifyClosestProof(proof, root, smt.Spec())
require.NoError(t, err)
require.True(t, result)
closestPath = sha256.Sum256([]byte("testKey4"))
@@ -336,7 +335,7 @@ func TestSMT_ProveClosest_Empty(t *testing.T) {
ClosestProof: &SparseMerkleProof{},
})

result, err := VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false))
result, err := VerifyClosestProof(proof, smt.Root(), smt.Spec())
require.NoError(t, err)
require.True(t, result)
}
@@ -368,7 +367,7 @@ func TestSMT_ProveClosest_OneNode(t *testing.T) {
ClosestProof: &SparseMerkleProof{},
})

result, err := VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false))
result, err := VerifyClosestProof(proof, smt.Root(), smt.Spec())
require.NoError(t, err)
require.True(t, result)
}