From 55db21b65fe9d0a108685e95461d7ef8f5cd4928 Mon Sep 17 00:00:00 2001 From: dmfxyz Date: Fri, 22 Mar 2024 23:32:38 -0400 Subject: [PATCH] initial optimization pass at getProof --- src/CompleteMerkle.sol | 69 +++++++++--------- src/test/CompleteMerkle.t.sol | 129 ---------------------------------- 2 files changed, 32 insertions(+), 166 deletions(-) delete mode 100644 src/test/CompleteMerkle.t.sol diff --git a/src/CompleteMerkle.sol b/src/CompleteMerkle.sol index 34cbca7..8d02460 100644 --- a/src/CompleteMerkle.sol +++ b/src/CompleteMerkle.sol @@ -10,21 +10,19 @@ contract CompleteMerkle { * HASHING FUNCTION * * */ - - function hashLeafPairs(bytes32 left, bytes32 right) public pure returns (bytes32 _hash) { assembly { - switch lt(left, right) - case 0 { - mstore(0x0, right) - mstore(0x20, left) - } - default { - mstore(0x0, left) - mstore(0x20, right) - } - _hash := keccak256(0x0, 0x40) + switch lt(left, right) + case 0 { + mstore(0x0, right) + mstore(0x20, left) + } + default { + mstore(0x0, left) + mstore(0x20, right) } + _hash := keccak256(0x0, 0x40) + } } function initTree(bytes32[] memory data) private pure returns (bytes32[] memory) { @@ -44,6 +42,10 @@ contract CompleteMerkle { function buildTree(bytes32[] memory data) public pure returns (bytes32[] memory) { bytes32[] memory tree = initTree(data); + // for (uint256 i = tree.length - 1; i > 0; i -= 2) { + // uint256 posToWrite = (i - 1) / 2; + // tree[posToWrite] = hashLeafPairs(tree[i - 1], tree[i]); + // } assembly { function hash_leafs(left, right) -> _hash { switch lt(left, right) @@ -70,7 +72,7 @@ contract CompleteMerkle { } function getRoot(bytes32[] memory data) public pure returns (bytes32) { - require(data.length > 1, "won't generate root for single leaf"); + require(data.length > 1, "wont generate root for single leaf"); bytes32[] memory tree = buildTree(data); return tree[0]; } @@ -89,7 +91,7 @@ contract CompleteMerkle { } _hash := keccak256(0x0, 0x40) } - let roll := mload(0x40) + let roll := mload(0x40) // TODO CHECK Sv.s.M mstore(roll, valueToProve) let len := mload(proof) for { let i := 0 } lt(i, len) { i := add(i, 1) } { @@ -101,37 +103,30 @@ contract CompleteMerkle { } function getProof(bytes32[] memory data, uint256 index) public pure returns (bytes32[] memory) { - require(data.length > 1, "won't generate proof for single leaf"); + require(data.length > 1, "wont generate proof for single leaf"); bytes32[] memory tree = buildTree(data); - assembly { + assembly ("memory-safe") { let iter := sub(sub(mload(tree), index), 0x1) let ptr := mload(0x40) mstore(ptr, 0x20) let proofSizePtr := add(ptr, 0x20) - let proofIndexPtr := proofSizePtr - for {} eq(0x0, 0x0) {} { - // WHile true - switch eq(iter, 0) - case 1 { break } - mstore(proofSizePtr, add(mload(proofSizePtr), 0x1)) + let proofIndexPtr := add(ptr, 0x40) + for {} 0x1 {} { + // while (true) + let sibling := mload(add(tree, mul(add(iter, shl(0x1, and(iter, 0x1))), 0x20))) // TODO: can iter mul also be accomplised by shifting? is iter always going to be either 0 or 0x1 + // something like shr(0x20, mul(iter()) ehh prob not + mstore(proofIndexPtr, sibling) + //iter := div(add(sub(iter,1), and(iter,0x1)), 2) // 82 108 + //iter := div(sub(iter,eq(and(iter,0x1), 0x0)), 2) // 181 231 -- 187 201 -- 184 211 + iter := shr(1, sub(iter, eq(and(iter, 0x1), 0x0))) // 183 211 -- 182 241 -- 177 224 + // switch eq(iter, 0) + // case 1 { break } + if eq(iter, 0x0) { break } proofIndexPtr := add(proofIndexPtr, 0x20) - switch and(iter, 0x1) - case 0x1 { - // iter is ODD - let sibling := mload(add(add(tree, 0x20), mul(add(iter, 1), 0x20))) - mstore(proofIndexPtr, sibling) - iter := div(iter, 2) - } - default { - // iter is EVEN - let sibling := mload(add(add(tree, 0x20), mul(sub(iter, 1), 0x20))) - mstore(proofIndexPtr, sibling) - iter := div(sub(iter, 1), 2) - } } - mstore(0x40, add(proofIndexPtr, 0x20)) - return(ptr, add(0x20, sub(add(proofIndexPtr, 0x20), proofSizePtr))) + mstore(proofSizePtr, div(sub(proofIndexPtr, proofSizePtr), 0x20)) + return(ptr, add(0x40, sub(proofIndexPtr, proofSizePtr))) } } } diff --git a/src/test/CompleteMerkle.t.sol b/src/test/CompleteMerkle.t.sol deleted file mode 100644 index 6078b21..0000000 --- a/src/test/CompleteMerkle.t.sol +++ /dev/null @@ -1,129 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.4; - -import "../CompleteMerkle.sol"; -import "../Merkle.sol"; -import "forge-std/Test.sol"; -import "openzeppelin-contracts/contracts/utils/cryptography/MerkleProof.sol"; -import "openzeppelin-contracts/contracts/utils/Strings.sol"; -import "forge-std/console.sol"; - -contract ContractTest is Test { - CompleteMerkle m; - Merkle om; - - function setUp() public { - m = new CompleteMerkle(); - om = new Merkle(); - } - - // function testHashes(bytes32 left, bytes32 right) public { - // bytes32 hAssem = m.hashLeafPairs(left, right); - // bytes memory packed; - // if (left <= right) { - // packed = abi.encodePacked(left, right); - // } else { - // if (right == bytes32(0x0)) { - // packed = abi.encodePacked(left, left); - // } else { - // packed = abi.encodePacked(right, left); - // } - // } - // bytes32 hNaive = keccak256(packed); - // assertEq(hAssem, hNaive); - // } - - // function testHashesDuplicateOddLeaf(bytes32 left) public { - // bytes32 right = bytes32(0x0); - // bytes32 hAssem = m.hashLeafPairs(left, right); - // bytes memory packed = abi.encodePacked(left, left); - // bytes32 hNaive = keccak256(packed); - // assertEq(hAssem, hNaive); - // } - - function testGenerateProof(bytes32[] memory data, uint256 node) public { - vm.assume(data.length > 1); - vm.assume(node < data.length); - bytes32 root = m.getRoot(data); - bytes32[] memory proof = m.getProof(data, node); - bytes32 valueToProve = data[node]; - - bytes32 rollingHash = valueToProve; - for (uint256 i = 0; i < proof.length; ++i) { - rollingHash = m.hashLeafPairs(rollingHash, proof[i]); - } - assertEq(rollingHash, root); - } - - function testVerifyProof(bytes32[] memory data, uint256 node) public { - vm.assume(data.length > 1); - vm.assume(node < data.length); - bytes32 root = m.getRoot(data); - bytes32[] memory proof = m.getProof(data, node); - bytes32 valueToProve = data[node]; - assertTrue(m.verifyProof(root, proof, valueToProve)); - } - - function testFailVerifyProof(bytes32[] memory data, bytes32 valueToProve, uint256 node) public { - vm.assume(data.length > 1); - vm.assume(node < data.length); - vm.assume(valueNotInArray(data, valueToProve)); - bytes32 root = m.getRoot(data); - bytes32[] memory proof = m.getProof(data, node); - assertTrue(m.verifyProof(root, proof, valueToProve)); - } - - function testWontGetRootSingleLeaf() public { - bytes32[] memory data = new bytes32[](1); - data[0] = bytes32(0x0); - vm.expectRevert("won't generate root for single leaf"); - m.getRoot(data); - } - - function testWontGetProofSingleLeaf() public { - bytes32[] memory data = new bytes32[](1); - data[0] = bytes32(0x0); - vm.expectRevert("won't generate proof for single leaf"); - m.getProof(data, 0x0); - } - - function valueNotInArray(bytes32[] memory data, bytes32 value) public pure returns (bool) { - for (uint256 i = 0; i < data.length; ++i) { - if (data[i] == value) return false; - } - return true; - } - - function testRootProofGenerationBasic6() public { - bytes32[] memory data = new bytes32[](6); - - data[0] = 0xcf0e8c2fa63ea2b3726dbea696df21baec00c4cdb37deeab03a15f190659544c; // sha3(alice, address(0x0)) - data[1] = 0xafeb0ccce3b008968e7ffbfc7482d85551f5f90e713b8441449d808d25e9cc64; // sha3(bob, address(0x0)) - data[2] = 0x7f37e8358c0d3959e3e800344c357551753c24a55d5d987cef461e933b137a02; // sha3(charlie, address(0x0)) - data[3] = 0xde1820ee7887b5ae922f14f423bb2e7a6595e423f1a0c0a82a2ddeed09a92a25; - data[4] = 0xcac6fc160d04af9e1fd8f0c71cf8d333453b39589d3846524462ee7737bd728d; - data[5] = 0x1688f29243f54ddded6dedcbbc8dae64ef939f0b967d0fa56a6e5938febb5f79; - bytes32 root = m.getRoot(data); - for (uint256 i = 0; i < data.length; ++i) { - bytes32[] memory proof = m.getProof(data, i); - assertTrue(m.verifyProof(root, proof, data[i])); - } - } - - function testRootGenerationOriginalMurky6ForGas() public { - bytes32[] memory data = new bytes32[](6); - - data[0] = 0xcf0e8c2fa63ea2b3726dbea696df21baec00c4cdb37deeab03a15f190659544c; // sha3(alice, address(0x0)) - data[1] = 0xafeb0ccce3b008968e7ffbfc7482d85551f5f90e713b8441449d808d25e9cc64; // sha3(bob, address(0x0)) - data[2] = 0x7f37e8358c0d3959e3e800344c357551753c24a55d5d987cef461e933b137a02; // sha3(charlie, address(0x0)) - data[3] = 0xde1820ee7887b5ae922f14f423bb2e7a6595e423f1a0c0a82a2ddeed09a92a25; - data[4] = 0xcac6fc160d04af9e1fd8f0c71cf8d333453b39589d3846524462ee7737bd728d; - data[5] = 0x1688f29243f54ddded6dedcbbc8dae64ef939f0b967d0fa56a6e5938febb5f79; - //bytes32[] memory tree = m._getTree(data); - bytes32 root = om.getRoot(data); - for (uint256 i = 0; i < data.length; ++i) { - bytes32[] memory proof = om.getProof(data, i); - assertTrue(om.verifyProof(root, proof, data[i])); - } - } -}