From 063c6d3760d0349128f98a8d857dc30e02998282 Mon Sep 17 00:00:00 2001 From: Lohann Paterno Coutinho Ferreira Date: Thu, 7 Nov 2024 09:55:39 +0000 Subject: [PATCH] Enumerate shards and routes (#26) --- lib/forge-std | 2 +- src/Gateway.sol | 146 +++++------------ src/NetworkID.sol | 87 ++++++---- src/interfaces/IExecutor.sol | 8 +- src/storage/Routes.sol | 261 ++++++++++++++++++++++++++++++ src/storage/Shards.sol | 297 +++++++++++++++++++++++++++++++++++ src/utils/EnumerableSet.sol | 237 ++++++++++++++++++++++++++++ src/utils/GasUtils.sol | 2 +- src/utils/Pointer.sol | 253 +++++++++++++++++++++++++++++ test/EnumerableSet.t.sol | 205 ++++++++++++++++++++++++ test/Gateway.t.sol | 6 +- 11 files changed, 1357 insertions(+), 147 deletions(-) create mode 100644 src/storage/Routes.sol create mode 100644 src/storage/Shards.sol create mode 100644 src/utils/EnumerableSet.sol create mode 100644 src/utils/Pointer.sol create mode 100644 test/EnumerableSet.t.sol diff --git a/lib/forge-std b/lib/forge-std index 1714bee..1eea5ba 160000 --- a/lib/forge-std +++ b/lib/forge-std @@ -1 +1 @@ -Subproject commit 1714bee72e286e73f76e320d110e0eaf5c4e649d +Subproject commit 1eea5bae12ae557d589f9f0f0edae2faa47cb262 diff --git a/src/Gateway.sol b/src/Gateway.sol index 1208140..df1b724 100644 --- a/src/Gateway.sol +++ b/src/Gateway.sol @@ -8,6 +8,7 @@ import {BranchlessMath} from "./utils/BranchlessMath.sol"; import {GasUtils} from "./utils/GasUtils.sol"; import {ERC1967} from "./utils/ERC1967.sol"; import {UFloat9x56, UFloatMath} from "./utils/Float9x56.sol"; +import {ShardStore} from "./storage/Shards.sol"; import {IGateway} from "./interfaces/IGateway.sol"; import {IUpgradable} from "./interfaces/IUpgradable.sol"; import {IGmpReceiver} from "./interfaces/IGmpReceiver.sol"; @@ -58,6 +59,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { using PrimitiveUtils for address; using BranchlessMath for uint256; using UFloatMath for UFloat9x56; + using ShardStore for ShardStore.MainStorage; uint8 internal constant SHARD_ACTIVE = (1 << 0); // Shard active bitflag uint8 internal constant SHARD_Y_PARITY = (1 << 1); // Pubkey y parity bitflag @@ -70,16 +72,9 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { // Non-zero value used to initialize the `prevMessageHash` storage bytes32 internal constant FIRST_MESSAGE_PLACEHOLDER = bytes32(uint256(2 ** 256 - 1)); - // Shard data, maps the pubkey coordX (which is already collision resistant) to shard info. - mapping(bytes32 => KeyInfo) private _shards; - // GMP message status mapping(bytes32 => GmpInfo) private _messages; - // GAP necessary for migration purposes - mapping(GmpSender => mapping(uint16 => uint256)) private _deprecated_Deposits; - mapping(uint16 => bytes32) private _deprecated_Networks; - // Hash of the previous GMP message submitted. bytes32 public prevMessageHash; @@ -90,18 +85,6 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { // Network ID => Source network mapping(uint16 => NetworkInfo) private _networkInfo; - /** - * @dev Shard info stored in the Gateway Contract - * OBS: the order of the attributes matters! ethereum storage is 256bit aligned, try to keep - * the shard info below 256 bit, so it can be stored in one single storage slot. - * reference: https://docs.soliditylang.org/en/latest/internals/layout_in_storage.html - */ - struct KeyInfo { - uint216 _gap; // gap, so we can use later for store more information about a shard - uint8 status; // 0 = unregisted, 1 = active, 2 = revoked - uint32 nonce; // shard nonce - } - /** * @dev GMP info stored in the Gateway Contract * OBS: the order of the attributes matters! ethereum storage is 256bit aligned, try to keep @@ -161,7 +144,8 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { _updateNetworks(networks); // Register keys - _registerKeys(keys); + ShardStore.MainStorage storage shards = ShardStore.getMainStorage(); + shards.registerTssKeys(keys); // emit event TssKey[] memory revoked = new TssKey[](0); @@ -172,8 +156,9 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { return _messages[id]; } - function keyInfo(bytes32 id) external view returns (KeyInfo memory) { - return _shards[id]; + function keyInfo(bytes32 id) external view returns (ShardStore.KeyInfo memory) { + ShardStore.MainStorage storage store = ShardStore.getMainStorage(); + return store.get(ShardStore.ShardID.wrap(id)); } function networkId() external view returns (uint16) { @@ -184,12 +169,20 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { return _networkInfo[id]; } + function listShards() external view returns (TssKey[] memory) { + return ShardStore.getMainStorage().listShards(); + } + /** * @dev Verify if shard exists, if the TSS signature is valid then increment shard's nonce. */ function _verifySignature(Signature calldata signature, bytes32 message) private view { // Load shard from storage - KeyInfo storage signer = _shards[bytes32(signature.xCoord)]; + ShardStore.KeyInfo storage signer; + { + ShardStore.MainStorage storage store = ShardStore.getMainStorage(); + signer = store.get(signature); + } // Verify if shard is active uint8 status = signer.status; @@ -206,11 +199,11 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { ); } - // Converts a `TssKey` into an `KeyInfo` unique identifier - function _tssKeyToShardId(TssKey memory tssKey) private pure returns (bytes32) { + // Converts a `TssKey` into an `ShardStore.ShardID` unique identifier + function _tssKeyToShardId(TssKey memory tssKey) private pure returns (ShardStore.ShardID) { // The tssKey coord x is already collision resistant // if we are unsure about it, we can hash the coord and parity bit - return bytes32(tssKey.xCoord); + return ShardStore.ShardID.wrap(bytes32(tssKey.xCoord)); } // Initialize networks @@ -227,87 +220,17 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { } } - function _registerKeys(TssKey[] memory keys) private { - // We don't perform any arithmetic operation, except iterate a loop - unchecked { - // Register or activate tss key (revoked keys keep the previous nonce) - for (uint256 i = 0; i < keys.length; i++) { - TssKey memory newKey = keys[i]; - - // Read shard from storage - bytes32 shardId = _tssKeyToShardId(newKey); - KeyInfo storage shard = _shards[shardId]; - uint8 status = shard.status; - uint32 nonce = shard.nonce; - - // Check if the shard is not active - require((status & SHARD_ACTIVE) == 0, "already active, cannot register again"); - - // Check y-parity - uint8 yParity = newKey.yParity; - require(yParity == (yParity & 1), "y parity bit must be 0 or 1, cannot register shard"); - - // If nonce is zero, it's a new shard. - // If the shard exists, the provided y-parity must match the original one - uint8 actualYParity = uint8(BranchlessMath.toUint((status & SHARD_Y_PARITY) > 0)); - require( - nonce == 0 || actualYParity == yParity, - "the provided y-parity doesn't match the existing y-parity, cannot register shard" - ); - - // if is a new shard shard, set its initial nonce to 1 - shard.nonce = uint32(BranchlessMath.ternaryU32(nonce == 0, 1, nonce)); - - // enable/disable the y-parity flag - status = BranchlessMath.ternaryU8(yParity > 0, status | SHARD_Y_PARITY, status & ~SHARD_Y_PARITY); - - // enable SHARD_ACTIVE bitflag - status |= SHARD_ACTIVE; - - // Save new status in the storage - shard.status = status; - } - } - } - - // Revoke TSS keys - function _revokeKeys(TssKey[] memory keys) private { - // We don't perform any arithmetic operation, except iterate a loop - unchecked { - // Revoke tss keys - for (uint256 i = 0; i < keys.length; i++) { - TssKey memory revokedKey = keys[i]; - - // Read shard from storage - bytes32 shardId = _tssKeyToShardId(revokedKey); - KeyInfo storage shard = _shards[shardId]; - - // Check if the shard exists and is active - require(shard.nonce > 0, "shard doesn't exists, cannot revoke key"); - require((shard.status & SHARD_ACTIVE) > 0, "cannot revoke a shard key already revoked"); - - // Check y-parity - { - uint8 yParity = (shard.status & SHARD_Y_PARITY) > 0 ? 1 : 0; - require(yParity == revokedKey.yParity, "invalid y parity bit, cannot revoke key"); - } - - // Disable SHARD_ACTIVE bitflag - shard.status = shard.status & (~SHARD_ACTIVE); // Disable active flag - } - } - } - // Register/Revoke TSS keys and emits [`KeySetChanged`] event function _updateKeys(bytes32 messageHash, TssKey[] memory keysToRevoke, TssKey[] memory newKeys) private { - // We don't perform any arithmetic operation, except iterate a loop - unchecked { - // Revoke tss keys (revoked keys can be registred again keeping the previous nonce) - _revokeKeys(keysToRevoke); + ShardStore.MainStorage storage shards = ShardStore.getMainStorage(); - // Register or activate revoked keys - _registerKeys(newKeys); - } + // Revoke tss keys (revoked keys can be registred again keeping the previous nonce) + shards.revokeKeys(keysToRevoke); + + // Register or activate revoked keys + shards.registerTssKeys(newKeys); + + // Emit event emit KeySetChanged(messageHash, keysToRevoke, newKeys); } @@ -664,17 +587,18 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { } // OBS: remove != revoke (when revoked, you cannot register again) - function sudoRemoveShards(TssKey[] memory shards) external payable { + function sudoRemoveShards(TssKey[] memory revokedKeys) external payable { require(msg.sender == _getAdmin(), "unauthorized"); - for (uint256 i; i < shards.length; i++) { - bytes32 shardId = _tssKeyToShardId(shards[i]); - delete _shards[shardId]; - } + ShardStore.MainStorage storage shards = ShardStore.getMainStorage(); + shards.revokeKeys(revokedKeys); + emit KeySetChanged(bytes32(0), revokedKeys, new TssKey[](0)); } - function sudoAddShards(TssKey[] memory shards) external payable { + function sudoAddShards(TssKey[] memory newKeys) external payable { require(msg.sender == _getAdmin(), "unauthorized"); - _registerKeys(shards); + ShardStore.MainStorage storage shards = ShardStore.getMainStorage(); + shards.registerTssKeys(newKeys); + emit KeySetChanged(bytes32(0), new TssKey[](0), newKeys); } // DANGER: This function is for migration purposes only, it allows the admin to set any storage slot. diff --git a/src/NetworkID.sol b/src/NetworkID.sol index bc9c890..b49d731 100644 --- a/src/NetworkID.sol +++ b/src/NetworkID.sol @@ -24,36 +24,63 @@ library NetworkIDHelpers { return NetworkID.unwrap(networkId); } - function chainId(NetworkID networkId) internal pure returns (uint64) { - uint256 id = NetworkID.unwrap(networkId); - uint256 chainid = type(uint256).max; - - // Ethereum Mainnet - chainid = BranchlessMath.ternary(id == asUint(MAINNET), 0, chainid); - // Astar - chainid = BranchlessMath.ternary(id == asUint(ASTAR), 592, chainid); - // Polygon PoS - chainid = BranchlessMath.ternary(id == asUint(POLYGON_POS), 137, chainid); - // Ethereum local testnet - chainid = BranchlessMath.ternary(id == asUint(ETHEREUM_LOCAL_DEV), 1337, chainid); - // Goerli - chainid = BranchlessMath.ternary(id == asUint(GOERLI), 5, chainid); - // Sepolia - chainid = BranchlessMath.ternary(id == asUint(SEPOLIA), 11155111, chainid); - // Astar local testnet - chainid = BranchlessMath.ternary(id == asUint(ASTAR_LOCAL_DEV), 592, chainid); - // Shibuya - chainid = BranchlessMath.ternary(id == asUint(SHIBUYA), 81, chainid); - // Polygon Amoy - chainid = BranchlessMath.ternary(id == asUint(POLYGON_AMOY), 80002, chainid); - // Binance Smart Chain - chainid = BranchlessMath.ternary(id == asUint(BINANCE_SMART_CHAIN_TESTNET), 97, chainid); - // Arbitrum Sepolia - chainid = BranchlessMath.ternary(id == asUint(ARBITRUM_SEPOLIA), 421614, chainid); - - require(chainid != type(uint256).max, "the provided network id doesn't exists"); - - return uint64(chainid); + /** + * @dev Get the EIP-150 chain id from the network id. + */ + function chainId(NetworkID networkId) internal pure returns (uint64 chainID) { + assembly { + switch networkId + case 0 { + // Ethereum Mainnet + chainID := 0 + } + case 1 { + // Astar + chainID := 592 + } + case 2 { + // Polygon PoS + chainID := 137 + } + case 3 { + // Ethereum local testnet + chainID := 1337 + } + case 4 { + // Goerli + chainID := 5 + } + case 5 { + // Sepolia + chainID := 11155111 + } + case 6 { + // Astar local testnet + chainID := 592 + } + case 7 { + // Shibuya + chainID := 81 + } + case 8 { + // Polygon Amoy + chainID := 80002 + } + case 9 { + // Binance Smart Chain + chainID := 97 + } + case 10 { + // Arbitrum Sepolia + chainID := 421614 + } + default { + // Unknown network id + chainID := 0xffffffffffffffff + } + } + require(chainID > 2 ** 24, "the provided network id doesn't exists"); + return uint64(chainID); } /** diff --git a/src/interfaces/IExecutor.sol b/src/interfaces/IExecutor.sol index 1bb3bd8..fe4d7c9 100644 --- a/src/interfaces/IExecutor.sol +++ b/src/interfaces/IExecutor.sol @@ -11,7 +11,8 @@ import { GmpStatus, UpdateKeysMessage, UpdateNetworkInfo, - GmpSender + GmpSender, + TssKey } from "../Primitives.sol"; /** @@ -38,6 +39,11 @@ interface IExecutor { */ event KeySetChanged(bytes32 indexed id, TssKey[] revoked, TssKey[] registered); + /** + * @dev List all shards currently registered in the gateway. + */ + function listShards() external returns (TssKey[] memory); + /** * Execute GMP message * @param signature Schnorr signature diff --git a/src/storage/Routes.sol b/src/storage/Routes.sol new file mode 100644 index 0000000..cd48497 --- /dev/null +++ b/src/storage/Routes.sol @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: MIT +// Analog's Contracts (last updated v0.1.0) (src/storage/Routes.sol) +pragma solidity ^0.8.20; + +import {UpdateNetworkInfo, Signature, Network} from "../Primitives.sol"; +import {NetworkIDHelpers, NetworkID} from "../NetworkID.sol"; +import {EnumerableSet, Pointer} from "../utils/EnumerableSet.sol"; +import {BranchlessMath} from "../utils/BranchlessMath.sol"; +import {UFloat9x56, UFloatMath} from "../utils/Float9x56.sol"; +import {StoragePtr} from "../utils/Pointer.sol"; + +/** + * @dev EIP-7201 Route's Storage + */ +library RouteStore { + using Pointer for StoragePtr; + using Pointer for uint256; + using EnumerableSet for EnumerableSet.Map; + using NetworkIDHelpers for NetworkID; + + /** + * @dev Namespace of the routes storage `analog.one.gateway.routes`. + * keccak256(abi.encode(uint256(keccak256("analog.one.gateway.routes")) - 1)) & ~bytes32(uint256(0xff)); + */ + bytes32 internal constant _EIP7201_NAMESPACE = 0xb184f2aad520cf7f1f1270909517c75ae33cdf2bd7d32b997a96577f11a48800; + + uint8 internal constant SHARD_ACTIVE = (1 << 0); // Shard active bitflag + uint8 internal constant SHARD_Y_PARITY = (1 << 1); // Pubkey y parity bitflag + + /** + * @dev Network info stored in the Gateway Contract + * @param domainSeparator Domain EIP-712 - Replay Protection Mechanism. + * @param gasLimit The maximum amount of gas we allow on this particular network. + * @param relativeGasPrice Gas price of destination chain, in terms of the source chain token. + * @param baseFee Base fee for cross-chain message approval on destination, in terms of source native gas token. + */ + struct NetworkInfo { + bytes32 domainSeparator; + uint64 gasLimit; + UFloat9x56 relativeGasPrice; + uint128 baseFee; + } + + /** + * @dev Network info stored in the Gateway Contract + * @param id Message unique id. + * @param networkId Network identifier. + * @param domainSeparator Domain EIP-712 - Replay Protection Mechanism. + * @param relativeGasPrice Gas price of destination chain, in terms of the source chain token. + * @param baseFee Base fee for cross-chain message approval on destination, in terms of source native gas token. + * @param gasLimit The maximum amount of gas we allow on this particular network. + */ + event NetworkUpdated( + bytes32 indexed id, + uint16 indexed networkId, + bytes32 indexed domainSeparator, + UFloat9x56 relativeGasPrice, + uint128 baseFee, + uint64 gasLimit + ); + + /** + * @dev Shard info stored in the Gateway Contract + * OBS: the order of the attributes matters! ethereum storage is 256bit aligned, try to keep + * the shard info below 256 bit, so it can be stored in one single storage slot. + * reference: https://docs.soliditylang.org/en/latest/internals/layout_in_storage.html + * + * @custom:storage-location erc7201:analog.one.gateway.routes + */ + struct MainStorage { + EnumerableSet.Map routes; + } + + error ShardAlreadyRegistered(NetworkID id); + error ShardNotExists(NetworkID id); + error IndexOutOfBounds(uint256 index); + + function getMainStorage() internal pure returns (MainStorage storage $) { + assembly { + $.slot := _EIP7201_NAMESPACE + } + } + + function asPtr(NetworkInfo storage keyInfo) internal pure returns (StoragePtr ptr) { + assembly { + ptr := keyInfo.slot + } + } + + function _ptrToRoute(StoragePtr ptr) private pure returns (NetworkInfo storage route) { + assembly { + route.slot := ptr + } + } + + /** + * @dev Returns true if the value is in the set. O(1). + */ + function contains(MainStorage storage store, NetworkInfo storage keyInfo) internal view returns (bool) { + return store.routes.contains(asPtr(keyInfo)); + } + + /** + * @dev Returns true if the value is in the set. O(1). + */ + function has(MainStorage storage store, NetworkID id) internal view returns (bool) { + return store.routes.has(bytes32(uint256(id.asUint()))); + } + + /** + * @dev Get or create a value. O(1). + * + * Returns true if the value was added to the set, that is if it was not + * already present. + */ + function getOrAdd(MainStorage storage store, NetworkID id) private returns (bool, NetworkInfo storage) { + (bool success, StoragePtr ptr) = store.routes.tryAdd(bytes32(uint256(id.asUint()))); + return (success, _ptrToRoute(ptr)); + } + + /** + * @dev Removes a value from a set. O(1). + * + * Returns true if the value was removed from the set, that is if it was + * present. + */ + function remove(MainStorage storage store, NetworkID id) internal returns (bool) { + StoragePtr ptr = store.routes.remove(bytes32(uint256(id.asUint()))); + if (ptr.isNull()) { + return false; + } + return true; + } + + /** + * @dev Returns the number of values on the set. O(1). + */ + function length(MainStorage storage store) internal view returns (uint256) { + return store.routes.length(); + } + + /** + * @dev Returns the value stored at position `index` in the set. O(1). + * + * Note that there are no guarantees on the ordering of values inside the + * array, and it may change when more values are added or removed. + * + * Requirements: + * + * - `index` must be strictly less than {length}. + */ + function at(MainStorage storage store, uint256 index) internal view returns (NetworkInfo storage) { + StoragePtr ptr = store.routes.at(index); + if (ptr.isNull()) { + revert IndexOutOfBounds(index); + } + return _ptrToRoute(ptr); + } + + /** + * @dev Returns the value associated with `NetworkInfo`. O(1). + * + * Requirements: + * - `NetworkInfo` must be in the map. + */ + function get(MainStorage storage store, NetworkID id) internal view returns (NetworkInfo storage) { + StoragePtr ptr = store.routes.get(bytes32(uint256(id.asUint()))); + if (ptr.isNull()) { + revert ShardNotExists(id); + } + return _ptrToRoute(ptr); + } + + /** + * @dev Returns the value associated with `NetworkInfo`. O(1). + */ + function tryGet(MainStorage storage store, NetworkID id) internal view returns (bool, NetworkInfo storage) { + (bool exists, StoragePtr ptr) = store.routes.tryGet(bytes32(uint256(id.asUint()))); + return (exists, _ptrToRoute(ptr)); + } + + function createOrUpdateNetworkInfo(MainStorage storage store, bytes32 messageHash, UpdateNetworkInfo calldata info) + private + { + require(info.mortality >= block.number, "message expired"); + + // Verify signature and if the message was already executed + // require(_executedMessages[messageHash] == bytes32(0), "message already executed"); + + // Update network info + (bool created, NetworkInfo storage stored) = getOrAdd(store, NetworkID.wrap(info.networkId)); + require(!created || info.domainSeparator != bytes32(0), "domain separator cannot be zero"); + + // Verify and update domain separator if it's not zero + if (info.domainSeparator != bytes32(0)) { + stored.domainSeparator = info.domainSeparator; + } + + // Update gas limit if it's not zero + if (info.gasLimit > 0) { + stored.gasLimit = info.gasLimit; + } + + // Update relative gas price and base fee if any of them are greater than zero + if (UFloat9x56.unwrap(info.relativeGasPrice) > 0 || info.baseFee > 0) { + stored.relativeGasPrice = info.relativeGasPrice; + stored.baseFee = info.baseFee; + } + + // Save the message hash to prevent replay attacks + // _executedMessages[messageHash] = executor; + + // Update network info + // _networkInfo[info.networkId] = stored; + + emit NetworkUpdated( + messageHash, + info.networkId, + stored.domainSeparator, + stored.relativeGasPrice, + stored.baseFee, + stored.gasLimit + ); + } + + // Initialize networks + function initialize( + MainStorage storage store, + Network[] calldata networks, + NetworkID networkdID, + function(Network calldata) internal pure returns(bytes32) computeDomainSeparator + ) private { + for (uint256 i = 0; i < networks.length; i++) { + Network calldata network = networks[i]; + (bool exists, NetworkInfo storage info) = tryGet(store, NetworkID.wrap(network.id)); + require(!exists && info.domainSeparator == bytes32(0), "network already initialized"); + require(network.id != networkdID.asUint() || network.gateway == address(this), "wrong gateway address"); + info.domainSeparator = computeDomainSeparator(network); + info.gasLimit = 15_000_000; // Default to 15M gas + info.relativeGasPrice = UFloatMath.ONE; + info.baseFee = 0; + } + } + + /** + * @dev Return all routes registered currently registered. + * + * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed + * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that + * this function has an unbounded cost, and using it as part of a state-changing function may render the function + * uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block. + */ + function listRoutes(MainStorage storage store) internal view returns (NetworkInfo[] memory) { + bytes32[] memory idx = store.routes.keys; + NetworkInfo[] memory routes = new NetworkInfo[](idx.length); + for (uint256 i = 0; i < idx.length; i++) { + routes[i] = _ptrToRoute(store.routes.values[idx[i]]); + } + return routes; + } +} diff --git a/src/storage/Shards.sol b/src/storage/Shards.sol new file mode 100644 index 0000000..30dddcc --- /dev/null +++ b/src/storage/Shards.sol @@ -0,0 +1,297 @@ +// SPDX-License-Identifier: MIT +// Analog's Contracts (last updated v0.1.0) (src/storage/Shards.sol) +pragma solidity ^0.8.20; + +import {TssKey, Signature} from "../Primitives.sol"; +import {EnumerableSet, Pointer} from "../utils/EnumerableSet.sol"; +import {BranchlessMath} from "../utils/BranchlessMath.sol"; +import {StoragePtr} from "../utils/Pointer.sol"; + +/** + * @dev EIP-7201 Shard's Storage + */ +library ShardStore { + using Pointer for StoragePtr; + using Pointer for uint256; + using EnumerableSet for EnumerableSet.Map; + + /** + * @dev Namespace of the shards storage `analog.one.gateway.shards`. + * keccak256(abi.encode(uint256(keccak256("analog.one.gateway.shards")) - 1)) & ~bytes32(uint256(0xff)); + */ + bytes32 internal constant _EIP7201_NAMESPACE = 0x582bcdebbeef4fb96dde802cfe96e9942657f4bedb5cfe94e8786bb683eb1f00; + + uint8 internal constant SHARD_ACTIVE = (1 << 0); // Shard active bitflag + uint8 internal constant SHARD_Y_PARITY = (1 << 1); // Pubkey y parity bitflag + + /** + * @dev Shard ID, this is the xCoord of the TssKey + */ + type ShardID is bytes32; + + /** + * @dev Current status of the shard + */ + enum ShardStatus { + Unregistered, + Active, + Revoked + } + + /** + * @dev Shard info stored in the Gateway Contract + * OBS: the order of the attributes matters! ethereum storage is 256bit aligned, try to keep + * the shard info below 256 bit, so it can be stored in one single storage slot. + * reference: https://docs.soliditylang.org/en/latest/internals/layout_in_storage.html + * + * @custom:storage-location erc7201:analog.one.gateway.shards + */ + struct KeyInfo { + uint8 status; + uint32 nonce; + } + + /** + * @dev Shard info stored in the Gateway Contract + * OBS: the order of the attributes matters! ethereum storage is 256bit aligned, try to keep + * the shard info below 256 bit, so it can be stored in one single storage slot. + * reference: https://docs.soliditylang.org/en/latest/internals/layout_in_storage.html + * + * @custom:storage-location erc7201:analog.one.gateway.shards + */ + struct MainStorage { + EnumerableSet.Map shards; + } + + error ShardAlreadyRegistered(ShardID id); + error ShardNotExists(ShardID id); + error IndexOutOfBounds(uint256 index); + + function getMainStorage() internal pure returns (MainStorage storage $) { + assembly { + $.slot := _EIP7201_NAMESPACE + } + } + + function asPtr(KeyInfo storage keyInfo) internal pure returns (StoragePtr ptr) { + assembly { + ptr := keyInfo.slot + } + } + + function _getKeyInfo(StoragePtr ptr) private pure returns (KeyInfo storage keyInfo) { + assembly { + keyInfo.slot := ptr + } + } + + /** + * @dev Returns true if the value is in the set. O(1). + */ + function contains(MainStorage storage store, KeyInfo storage keyInfo) internal view returns (bool) { + return store.shards.contains(asPtr(keyInfo)); + } + + /** + * @dev Returns true if the value is in the set. O(1). + */ + function has(MainStorage storage store, ShardID id) internal view returns (bool) { + return store.shards.has(ShardID.unwrap(id)); + } + + /** + * @dev Get or create a value. O(1). + * + * Returns true if the value was added to the set, that is if it was not + * already present. + */ + function getOrAdd(MainStorage storage store, ShardID xCoord) private returns (bool, KeyInfo storage) { + (bool success, StoragePtr ptr) = store.shards.tryAdd(ShardID.unwrap(xCoord)); + return (success, _getKeyInfo(ptr)); + } + + /** + * @dev Removes a value from a set. O(1). + * + * Returns true if the value was removed from the set, that is if it was + * present. + */ + function remove(MainStorage storage store, ShardID id) internal returns (bool) { + StoragePtr ptr = store.shards.remove(ShardID.unwrap(id)); + if (ptr.isNull()) { + return false; + } + KeyInfo storage keyInfo = _getKeyInfo(ptr); + keyInfo.status &= ~SHARD_ACTIVE; + return true; + } + + /** + * @dev Returns the number of values on the set. O(1). + */ + function length(MainStorage storage store) internal view returns (uint256) { + return store.shards.length(); + } + + /** + * @dev Returns the value stored at position `index` in the set. O(1). + * + * Note that there are no guarantees on the ordering of values inside the + * array, and it may change when more values are added or removed. + * + * Requirements: + * + * - `index` must be strictly less than {length}. + */ + function at(MainStorage storage store, uint256 index) internal view returns (KeyInfo storage) { + StoragePtr ptr = store.shards.at(index); + if (ptr.isNull()) { + revert IndexOutOfBounds(index); + } + return _getKeyInfo(ptr); + } + + /** + * @dev Returns the value associated with `key`. O(1). + * + * Requirements: + * - `key` must be in the map. + */ + function get(MainStorage storage store, ShardID key) internal view returns (KeyInfo storage) { + StoragePtr ptr = store.shards.get(ShardID.unwrap(key)); + if (ptr.isNull()) { + revert ShardNotExists(key); + } + return _getKeyInfo(ptr); + } + + /** + * @dev Returns the `KeyInfo` associated with `TssKey`. O(1). + * + * Requirements: + * - `key.xCoord` must be in the map. + */ + function get(MainStorage storage store, TssKey calldata key) internal view returns (KeyInfo storage) { + return get(store, ShardID.wrap(bytes32(key.xCoord))); + } + + /** + * @dev Returns the `KeyInfo` associated with `Signature`. O(1). + * + * Requirements: + * - `signature.xCoord` must be in the map. + */ + function get(MainStorage storage store, Signature calldata signature) internal view returns (KeyInfo storage) { + return get(store, ShardID.wrap(bytes32(signature.xCoord))); + } + + /** + * @dev Returns the value associated with `key`. O(1). + */ + function tryGet(MainStorage storage store, ShardID key) internal view returns (bool, KeyInfo storage) { + (bool exists, StoragePtr ptr) = store.shards.tryGet(ShardID.unwrap(key)); + return (exists, _getKeyInfo(ptr)); + } + + /** + * @dev Register TSS keys. + * Requirements: + * - The `keys` should not be already registered. + */ + function registerTssKeys(MainStorage storage store, TssKey[] memory keys) internal { + // We don't perform any arithmetic operation, except iterate a loop + unchecked { + // Register or activate tss key (revoked keys keep the previous nonce) + for (uint256 i = 0; i < keys.length; i++) { + TssKey memory newKey = keys[i]; + uint8 yParity = newKey.yParity; + require(yParity == (yParity & 1), "y parity bit must be 0 or 1, cannot register shard"); + + // Read shard from storage + ShardID id = ShardID.wrap(bytes32(newKey.xCoord)); + (bool success, KeyInfo storage shard) = getOrAdd(store, id); + + // Check if the shard is already registered + if (!success) { + revert ShardAlreadyRegistered(id); + } + + uint32 nonce = shard.nonce; + uint8 status = shard.status; + { + uint8 actualYParity = uint8(BranchlessMath.toUint((status & SHARD_Y_PARITY) > 0)); + require( + nonce == 0 || actualYParity == yParity, + "the provided y-parity doesn't match the existing y-parity, cannot register shard" + ); + nonce += uint32(BranchlessMath.toUint(nonce == 0)); + } + + // enable/disable the y-parity flag + status = BranchlessMath.ternaryU8(yParity > 0, status | SHARD_Y_PARITY, status & ~SHARD_Y_PARITY); + status |= SHARD_ACTIVE; + + // Save new status and nonce in the storage + shard.status = status; + shard.nonce = nonce; + } + } + } + + /** + * @dev Register TSS keys. + * Requirements: + * - The `keys` must be registered. + */ + function revokeKeys(MainStorage storage store, TssKey[] memory keys) internal { + // We don't perform any arithmetic operation, except iterate a loop + unchecked { + // Revoke tss keys + for (uint256 i = 0; i < keys.length; i++) { + TssKey memory revokedKey = keys[i]; + + // Read shard from storage + ShardID id = ShardID.wrap(bytes32(revokedKey.xCoord)); + KeyInfo storage shard; + { + bool shardExists; + (shardExists, shard) = tryGet(store, id); + + if (!shardExists || shard.nonce == 0) { + revert ShardNotExists(id); + } + } + + // Check y-parity + { + uint8 yParity = (shard.status & SHARD_Y_PARITY) > 0 ? 1 : 0; + require(yParity == revokedKey.yParity, "y parity bit mismatch, cannot revoke key"); + } + + // Disable SHARD_ACTIVE bitflag + shard.status = shard.status & (~SHARD_ACTIVE); // Disable active flag + + // Remove from the set + store.shards.remove(ShardID.unwrap(id)); + } + } + } + + /** + * @dev Return all shards registered currently registered. + * + * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed + * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that + * this function has an unbounded cost, and using it as part of a state-changing function may render the function + * uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block. + */ + function listShards(MainStorage storage store) internal view returns (TssKey[] memory) { + bytes32[] memory idx = store.shards.keys; + TssKey[] memory keys = new TssKey[](idx.length); + for (uint256 i = 0; i < idx.length; i++) { + KeyInfo storage keyInfo = _getKeyInfo(store.shards.values[idx[i]]); + keys[i] = TssKey(keyInfo.status & SHARD_Y_PARITY, uint256(idx[i])); + } + return keys; + } +} diff --git a/src/utils/EnumerableSet.sol b/src/utils/EnumerableSet.sol new file mode 100644 index 0000000..cff2a74 --- /dev/null +++ b/src/utils/EnumerableSet.sol @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// Analog's Contracts (last updated v0.1.0) (src/utils/EnumerableMap.sol) +pragma solidity ^0.8.20; + +import {StoragePtr, Pointer} from "./Pointer.sol"; + +/** + * @dev Library for managing an enumerable variant of Solidity's + * https://solidity.readthedocs.io/en/latest/types.html#mapping-types[`mapping`] + * type. + */ +library EnumerableSet { + using Pointer for StoragePtr; + + error ValueAlreadyPresent(bytes32); + + /** + * @dev Shard info stored in the Gateway Contract + * OBS: the order of the attributes matters! ethereum storage is 256bit aligned, try to keep + * the shard info below 256 bit, so it can be stored in one single storage slot. + * reference: https://docs.soliditylang.org/en/latest/internals/layout_in_storage.html + * + * @custom:storage-location erc7201:analog.one.gateway.shards + */ + struct Map { + bytes32[] keys; + mapping(bytes32 => StoragePtr) values; + } + + /** + * @dev Returns index of a given value in the set. O(1). + * + * Returns -1 if the value is not in the set. + */ + function indexOf(Map storage map, StoragePtr ptr) internal view returns (int256 index) { + assembly ("memory-safe") { + index := not(sload(sub(ptr, 1))) + mstore(0x00, map.slot) + mstore(0x00, sload(add(keccak256(0x00, 0x20), index))) + mstore(0x20, add(map.slot, 1)) + index := or(index, sub(and(eq(ptr, keccak256(0x00, 0x40)), lt(index, sload(map.slot))), 1)) + } + } + + /** + * @dev Returns true if the value is in the set. O(1). + */ + function contains(Map storage map, StoragePtr value) internal view returns (bool r) { + return indexOf(map, value) >= 0; + } + + /** + * @dev Returns true if the key is in the set. O(1). + */ + function has(Map storage map, bytes32 key) internal view returns (bool r) { + assembly ("memory-safe") { + mstore(0x00, key) + mstore(0x20, add(map.slot, 1)) + r := keccak256(0x00, 0x40) + r := gt(sload(sub(r, 1)), 0) + } + } + + /** + * @dev Add a value to a set. O(1). + * + * Returns true if the value was added to the set, that is if it was not + * already present. + */ + function add(Map storage map, bytes32 key) internal returns (StoragePtr r) { + bool success; + (success, r) = tryAdd(map, key); + if (!success) { + revert ValueAlreadyPresent(key); + } + } + + /** + * @dev Add a value to a set. O(1). + * + * Returns true if the value was added to the set, that is if it was not + * already present. + */ + function tryAdd(Map storage map, bytes32 key) internal returns (bool success, StoragePtr r) { + assembly ("memory-safe") { + mstore(0x00, key) + mstore(0x20, add(map.slot, 1)) + r := keccak256(0x00, 0x40) + success := iszero(sload(sub(r, 1))) + if success { + // Load the array size + let size := sload(map.slot) + + // Store the value + mstore(0x00, map.slot) + sstore(add(keccak256(0x00, 0x20), size), key) + + // Update the value's index + sstore(sub(r, 1), not(size)) + + // Update array size + size := add(size, 1) + sstore(map.slot, size) + } + } + } + + /** + * @dev Removes a value from a set. O(1). + * + * Returns the removed value storage pointer, if it was present, or null if it was not. + */ + function remove(Map storage map, bytes32 key) internal returns (StoragePtr r) { + assembly ("memory-safe") { + // Find the value's index + mstore(0x00, key) + mstore(0x20, add(map.slot, 1)) + r := keccak256(0x00, 0x40) + let index := not(sload(sub(r, 1))) + + // First element storage index + let keys_count := sload(map.slot) + mstore(0x00, map.slot) + let keys_start := keccak256(0x00, 0x20) + let val_key_ptr := add(keys_start, index) + + // (index < map.keys.length) && key == map.keys[map.values[key].index] + r := mul(r, and(lt(index, keys_count), eq(key, sload(val_key_ptr)))) + + if r { + // (index + 1) < map.keys.length + if lt(add(index, 1), keys_count) { + // Move the last element to the removed element's position + let last_index := sub(keys_count, 1) + let last_key := sload(add(keys_start, last_index)) + sstore(val_key_ptr, last_key) + + // Update the last element's index + mstore(0x00, last_key) + mstore(0x20, add(map.slot, 1)) + sstore(sub(keccak256(0x00, 0x40), 1), not(index)) + } + + // Update array size + sstore(map.slot, sub(keys_count, 1)) + + // Remove index + sstore(sub(r, 1), 0) + } + } + } + + /** + * @dev Returns the number of values on the set. O(1). + */ + function length(Map storage map) internal view returns (uint256) { + return map.keys.length; + } + + /** + * @dev Returns the value stored at position `index` in the set. O(1). + * + * Note that there are no guarantees on the ordering of values inside the + * array, and it may change when more values are added or removed. + * + * Requirements: + * + * - `index` must be strictly less than {length}. + */ + function at(Map storage map, uint256 index) internal view returns (StoragePtr r) { + assembly ("memory-safe") { + mstore(0x00, map.slot) + let key := sload(add(keccak256(0x00, 0x20), index)) + mstore(0x00, key) + mstore(0x20, add(map.slot, 1)) + r := keccak256(0x00, 0x40) + key := not(sload(sub(r, 1))) + r := mul(r, and(lt(index, sload(map.slot)), eq(index, key))) + } + } + + /** + * @dev Returns the value associated with `key`. O(1). + * + * Requirements: + * - `key` must be in the map. + */ + function get(Map storage map, bytes32 key) internal view returns (StoragePtr r) { + assembly ("memory-safe") { + mstore(0x00, key) + mstore(0x20, add(map.slot, 1)) + r := keccak256(0x00, 0x40) + r := mul(r, gt(sload(sub(r, 1)), 0)) + } + } + + /** + * @dev Tries to returns the value associated with `key`. O(1). + * Does not revert if `key` is not in the map. + */ + function tryGet(Map storage map, bytes32 key) internal view returns (bool exists, StoragePtr r) { + assembly ("memory-safe") { + mstore(0x00, key) + mstore(0x20, add(map.slot, 1)) + r := keccak256(0x00, 0x40) + exists := gt(sload(sub(r, 1)), 0) + } + } + + /** + * @dev Returns the value associated with `key`. O(1). + */ + function getUnchecked(Map storage map, bytes32 key) internal pure returns (StoragePtr r) { + assembly ("memory-safe") { + mstore(0x00, key) + mstore(0x20, add(map.slot, 1)) + r := keccak256(0x00, 0x40) + } + } + + // /** + // * @dev Return the entire set in an array + // * + // * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed + // * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that + // * this function has an unbounded cost, and using it as part of a state-changing function may render the function + // * uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block. + // */ + // function _values(EnumerableMap storage m) private view returns (KeyInfo[] memory) { + // ShardID[] memory keys = s.keys; + // KeyInfo[] memory values = new KeyInfo[](keys.length); + // for (uint256 i = 0; i < keys.length; i++) { + // values[i] = s.shards[keys[i]]; + // } + // return values; + // } +} diff --git a/src/utils/GasUtils.sol b/src/utils/GasUtils.sol index ed270bb..0d5f12c 100644 --- a/src/utils/GasUtils.sol +++ b/src/utils/GasUtils.sol @@ -13,7 +13,7 @@ library GasUtils { /** * @dev Base cost of the `IExecutor.execute` method. */ - uint256 internal constant EXECUTION_BASE_COST = 44469; + uint256 internal constant EXECUTION_BASE_COST = 44469 + 2245 - 11; /** * @dev Base cost of the `IGateway.submitMessage` method. diff --git a/src/utils/Pointer.sol b/src/utils/Pointer.sol new file mode 100644 index 0000000..ef92f9a --- /dev/null +++ b/src/utils/Pointer.sol @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: MIT +// Analog's Contracts (last updated v0.1.0) (src/utils/StoragePtr.sol) +pragma solidity ^0.8.20; + +/** + * @dev Represents a raw pointer to a value in storage. + */ +type StoragePtr is uint256; + +/** + * @dev Library for reading and writing primitive types to specific storage slots. + * + * Storage slots are often used to avoid storage conflict when dealing with upgradeable contracts. + * This library helps with reading and writing to such slots without the need for inline assembly. + * + * The functions in this library return Slot structs that contain a `value` member that can be used to read or write. + * + * Example usage to set ERC-1967 implementation slot: + * ```solidity + * contract ERC1967 { + * // Define the slot. Alternatively, use the SlotDerivation library to derive the slot. + * bytes32 internal constant _IMPLEMENTATION_SLOT = 0x360894a13ba1a3210667c828492db98dca3e2076cc3735a920a3ca505d382bbc; + * + * function _getImplementation() internal view returns (address) { + * return StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value; + * } + * + * function _setImplementation(address newImplementation) internal { + * require(newImplementation.code.length > 0); + * StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value = newImplementation; + * } + * } + * ``` + * + * TIP: Consider using this library along with {SlotDerivation}. + */ +library Pointer { + struct AddressSlot { + address value; + } + + struct BooleanSlot { + bool value; + } + + struct Bytes32Slot { + bytes32 value; + } + + struct Uint256Slot { + uint256 value; + } + + struct Int256Slot { + int256 value; + } + + struct StringSlot { + string value; + } + + struct BytesSlot { + bytes value; + } + + /** + * @dev Converts `uint256[] storage` to `StoragePtr`. + */ + function asPtr(uint256[] storage store) internal pure returns (StoragePtr ptr) { + assembly ("memory-safe") { + ptr := store.slot + } + } + + /** + * @dev Converts `bytes32[] storage` to `StoragePtr`. + */ + function asPtr(bytes32[] storage store) internal pure returns (StoragePtr ptr) { + assembly ("memory-safe") { + ptr := store.slot + } + } + + /** + * @dev Converts `bytes storage` to `StoragePtr`. + */ + function asPtr(bytes storage store) internal pure returns (StoragePtr ptr) { + assembly ("memory-safe") { + ptr := store.slot + } + } + + /** + * @dev Wraps a value in a `StoragePtr`. + */ + function asPtr(uint256 value) internal pure returns (StoragePtr) { + return StoragePtr.wrap(value); + } + + /** + * @dev Unwraps a `StoragePtr` to a value. + */ + function asPtr(bytes32 value) internal pure returns (StoragePtr) { + return StoragePtr.wrap(uint256(value)); + } + + /** + * @dev Convert a `StoragePtr` to `uint256`. + */ + function asUint(StoragePtr ptr) internal pure returns (uint256) { + return StoragePtr.unwrap(ptr); + } + + /** + * @dev Convert a `StoragePtr` to `int256`. + */ + function asInt(StoragePtr ptr) internal pure returns (int256) { + return int256(StoragePtr.unwrap(ptr)); + } + + /** + * @dev Convert a `StoragePtr` to `bytes32`. + */ + function asBytes32(StoragePtr ptr) internal pure returns (bytes32) { + return bytes32(StoragePtr.unwrap(ptr)); + } + + /** + * @dev Whether the `StoragePtr` is zero or not. + */ + function isNull(StoragePtr ptr) internal pure returns (bool r) { + assembly ("memory-safe") { + r := iszero(ptr) + } + } + + /** + * @dev Returns an `AddressSlot` with member `value` located at `slot`. + */ + function getAddressSlot(StoragePtr slot) internal pure returns (AddressSlot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Converts a `AddressSlot` into an `StoragePtr`. + */ + function asPtr(AddressSlot storage store) internal pure returns (StoragePtr ptr) { + assembly ("memory-safe") { + ptr := store.slot + } + } + + /** + * @dev Returns a `BooleanSlot` with member `value` located at `slot`. + */ + function getBooleanSlot(StoragePtr slot) internal pure returns (BooleanSlot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Converts a `BooleanSlot` into an `StoragePtr`. + */ + function asPtr(BooleanSlot storage store) internal pure returns (StoragePtr ptr) { + assembly ("memory-safe") { + ptr := store.slot + } + } + + /** + * @dev Returns a `Bytes32Slot` with member `value` located at `slot`. + */ + function getBytes32Slot(StoragePtr slot) internal pure returns (Bytes32Slot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Converts a `Bytes32Slot` into an `StoragePtr`. + */ + function asPtr(Bytes32Slot storage store) internal pure returns (StoragePtr ptr) { + assembly ("memory-safe") { + ptr := store.slot + } + } + + /** + * @dev Returns a `Uint256Slot` with member `value` located at `slot`. + */ + function getUint256Slot(StoragePtr slot) internal pure returns (Uint256Slot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Converts a `Uint256Slot` into an `StoragePtr`. + */ + function asPtr(Uint256Slot storage store) internal pure returns (StoragePtr ptr) { + assembly ("memory-safe") { + ptr := store.slot + } + } + + /** + * @dev Returns a `Int256Slot` with member `value` located at `slot`. + */ + function getInt256Slot(StoragePtr slot) internal pure returns (Int256Slot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Converts a `Int256Slot` into an `StoragePtr`. + */ + function asPtr(Int256Slot storage store) internal pure returns (StoragePtr ptr) { + assembly ("memory-safe") { + ptr := store.slot + } + } + + /** + * @dev Returns an `StringSlot` representation of the string storage pointer `store`. + */ + function getStringSlot(StoragePtr slot) internal pure returns (StringSlot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Returns a `BytesSlot` with member `value` located at `slot`. + */ + function getBytesSlot(StoragePtr slot) internal pure returns (BytesSlot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Converts a `BytesSlot` into an `StoragePtr`. + */ + function asPtr(BytesSlot storage store) internal pure returns (StoragePtr ptr) { + assembly ("memory-safe") { + ptr := store.slot + } + } +} diff --git a/test/EnumerableSet.t.sol b/test/EnumerableSet.t.sol new file mode 100644 index 0000000..3ad112f --- /dev/null +++ b/test/EnumerableSet.t.sol @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Analog's Contracts (last updated v0.1.0) (test/EnumerableSet.t.sol) + +pragma solidity >=0.8.0; + +import {Test, console} from "forge-std/Test.sol"; +import {TestUtils} from "./TestUtils.sol"; +import {EnumerableSet, StoragePtr} from "../src/utils/EnumerableSet.sol"; +import {StoragePtr, Pointer} from "../src/utils/Pointer.sol"; +import {BranchlessMath} from "../src/utils/BranchlessMath.sol"; + +contract EnumerableSetTest is Test { + using BranchlessMath for uint256; + using EnumerableSet for EnumerableSet.Map; + using Pointer for StoragePtr; + using Pointer for uint256; + using Pointer for Pointer.Uint256Slot; + + uint256 private constant ITERATIONS = 10; + + EnumerableSet.Map private map; + + struct MyStruct { + uint256 a; + uint256 b; + uint256 c; + } + + function _add(uint256 key, uint256 value, bool success) private returns (MyStruct storage r) { + bytes32 ptr = map.add(bytes32(key)).asBytes32(); + if (success) { + assertNotEq(ptr, bytes32(0), "map.add failed"); + assembly { + r.slot := ptr + } + r.a = value; + r.b = value + 1; + r.c = value + 2; + } else { + assertEq(ptr, bytes32(0), "expect map.add to fail"); + assembly { + r.slot := 0 + } + } + } + + function _at(uint256 index, bool success) private view returns (MyStruct storage r) { + bytes32 ptr = map.at(index).asBytes32(); + if (success) { + assertNotEq(ptr, bytes32(0), "map.at failed"); + assembly { + r.slot := ptr + } + } else { + assertEq(ptr, bytes32(0), "expect map.at to fail"); + assembly { + r.slot := 0 + } + } + } + + function _get(uint256 key, bool success) private view returns (MyStruct storage r) { + bytes32 ptr = map.get(bytes32(key)).asBytes32(); + if (success) { + assertNotEq(ptr, bytes32(0), "map.at failed"); + assembly { + r.slot := ptr + } + } else { + assertEq(ptr, bytes32(0), "expect map.at to fail"); + assembly { + r.slot := 0 + } + } + } + + function _removeByKey(uint256 key, bool success) private returns (MyStruct storage r) { + bytes32 ptr = map.remove(bytes32(key)).asBytes32(); + if (success) { + assertNotEq(ptr, bytes32(0), "map.at failed"); + assembly { + r.slot := ptr + } + } else { + assertEq(ptr, bytes32(0), "expect map.at to fail"); + assembly { + r.slot := 0 + } + } + } + + /** + * Test if `Map.add` and `Map.at` work as expected. + */ + function test_add() external { + assertEq(map.length(), 0, "Map should be empty"); + + MyStruct storage s; + for (uint256 i = 0; i < ITERATIONS; i++) { + s = _add(0x1234 + i, i + 1, true); + assertEq(map.length(), i + 1); + for (uint256 j = 0; j < ITERATIONS; j++) { + s = _at(j, j <= i); + if (j <= i) { + assertEq(s.a, j + 1, "MyStruct.a mismatch"); + assertEq(s.b, j + 2, "MyStruct.b mismatch"); + assertEq(s.c, j + 3, "MyStruct.c mismatch"); + } + } + } + } + + /** + * Test if `Map.add` and `Map.at` work as expected. + */ + function test_remove() external { + assertEq(map.length(), 0, "Map should be empty"); + + MyStruct storage s; + for (uint256 i = 0; i < ITERATIONS; i++) { + s = _add(0xdeadbeef + i, i + 1, true); + } + + assertEq(map.length(), ITERATIONS, "unexpected map length"); + uint256 count = ITERATIONS - 1; + _removeByKey(0xdeadbeef + count, true); + assertEq(map.length(), count, "element not removed"); + + // Cannot remove the same key twice + _removeByKey(0xdeadbeef + count, false); + + // Cannot remove an unknown key + _removeByKey(0xdeadbeef + ITERATIONS * 2, false); + + for (uint256 i = 0; i < ITERATIONS; i++) { + s = _at(i, i < count); + if (i < count) { + assertEq(s.a, i + 1, "MyStruct.a mismatch"); + assertEq(s.b, i + 2, "MyStruct.b mismatch"); + assertEq(s.c, i + 3, "MyStruct.c mismatch"); + } + } + + uint256 removeIndex = count - 3; + _removeByKey(0xdeadbeef + removeIndex, true); + count -= 1; + assertEq(map.length(), count, "element not removed"); + s = _at(removeIndex, true); + assertEq(s.a, count + 1, "MyStruct.a mismatch"); + assertEq(s.b, count + 2, "MyStruct.b mismatch"); + assertEq(s.c, count + 3, "MyStruct.c mismatch"); + + s = _at(removeIndex + 1, true); + assertEq(s.a, removeIndex + 2, "MyStruct.a mismatch"); + assertEq(s.b, removeIndex + 3, "MyStruct.b mismatch"); + assertEq(s.c, removeIndex + 4, "MyStruct.c mismatch"); + + for (uint256 i = 0; i < map.keys.length; i++) { + uint256 key = uint256(map.keys[i]); + assertEq(map.values[bytes32(key)].asUint(), key - 0xdeadbeef + 1, "MyStruct.a mismatch"); + s = _get(key, true); + assertEq(s.a, key - 0xdeadbeef + 1, "MyStruct.a mismatch"); + assertEq(s.b, key - 0xdeadbeef + 2, "MyStruct.b mismatch"); + assertEq(s.c, key - 0xdeadbeef + 3, "MyStruct.c mismatch"); + } + } + + /** + * Test if `Map.add` and `Map.at` work as expected. + */ + function test_fuzz(bytes32 key, uint256 value) external { + assertEq(map.length(), 0, "Map should be empty"); + + // Map.length works + Pointer.Uint256Slot storage store; + store = map.add(key).getUint256Slot(); + assertFalse(store.asPtr().isNull(), "invalid pointer"); + store.value = value; + + // Map.length works + assertEq(map.length(), 1, "unexpected map length"); + + // Map.get works + store = map.get(key).getUint256Slot(); + assertEq(store.value, value, "unexpected value when retrieving by key"); + + // Map.indexOf works + int256 index = map.indexOf(store.asPtr()); + assertEq(index, 0, "unexpected index"); + + // Map.at works + store = map.at(0).getUint256Slot(); + assertEq(store.value, value, "unexpected value when retrieving by index"); + + // Map.contains works + StoragePtr ptr = map.get(key); + assertTrue(map.contains(ptr), "the key should be in the map"); + + // Map.contains returns false for invalid pointers + ptr = (ptr.asUint() + 1).asPtr(); + assertFalse(map.contains(ptr), "invalid pointer"); + ptr = (ptr.asUint() - 2).asPtr(); + assertFalse(map.contains(ptr), "invalid pointer"); + } +} diff --git a/test/Gateway.t.sol b/test/Gateway.t.sol index 5c7e69e..79eb0e5 100644 --- a/test/Gateway.t.sol +++ b/test/Gateway.t.sol @@ -216,7 +216,7 @@ contract GatewayBase is Test { function test_estimateMessageCost() external { vm.txGasPrice(1); uint256 cost = gateway.estimateMessageCost(DEST_NETWORK_ID, 96, 100000); - assertEq(cost, 178501); + assertEq(cost, GasUtils.EXECUTION_BASE_COST + 134032); } function test_checkPayloadSize() external { @@ -394,7 +394,7 @@ contract GatewayBase is Test { ctx.value = GasUtils.estimateGas(uint16(nonZeros), uint16(zeros), gmp.gasLimit); } - uint256 snapshot = vm.snapshot(); + uint256 snapshot = vm.snapshotState(); // Must work if the funds and gas limit are sufficient bytes32 id = gmp.eip712TypedHash(_dstDomainSeparator); vm.expectEmit(true, true, true, true); @@ -411,7 +411,7 @@ contract GatewayBase is Test { ); // Must revert if fund are insufficient - vm.revertTo(snapshot); + vm.revertToState(snapshot); ctx.value -= 1; vm.expectRevert("insufficient tx value"); ctx.submitMessage(gmp);