From b80ab5fb3de8e4d315a0e22d04c9f1df437f5e74 Mon Sep 17 00:00:00 2001 From: Lohann Paterno Coutinho Ferreira Date: Mon, 4 Nov 2024 15:11:14 -0300 Subject: [PATCH] Move shard logic to shard storage --- src/Gateway.sol | 81 +++++++++++-------------------- src/storage/Shards.sol | 95 ++++++++++++++++++++++++++++++++----- src/utils/EnumerableSet.sol | 8 ++++ test/EnumerableSet.t.sol | 2 +- 4 files changed, 120 insertions(+), 66 deletions(-) diff --git a/src/Gateway.sol b/src/Gateway.sol index b88be45..d571e82 100644 --- a/src/Gateway.sol +++ b/src/Gateway.sol @@ -78,10 +78,6 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { // GMP message status mapping(bytes32 => GmpInfo) private _messages; - // GAP necessary for migration purposes - mapping(GmpSender => mapping(uint16 => uint256)) private _deprecated_Deposits; - mapping(uint16 => bytes32) private _deprecated_Networks; - // Hash of the previous GMP message submitted. bytes32 public prevMessageHash; @@ -179,7 +175,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { function keyInfo(bytes32 id) external view returns (ShardStore.KeyInfo memory) { ShardStore.MainStorage storage store = ShardStore.getMainStorage(); - return store.get(id); + return store.get(ShardStore.ShardID.wrap(id)); } function networkId() external view returns (uint16) { @@ -195,7 +191,11 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { */ function _verifySignature(Signature calldata signature, bytes32 message) private view { // Load shard from storage - KeyInfo storage signer = _shards[bytes32(signature.xCoord)]; + ShardStore.KeyInfo storage signer; + { + ShardStore.MainStorage storage store = ShardStore.getMainStorage(); + signer = store.get(ShardStore.ShardID.wrap(bytes32(signature.xCoord))); + } // Verify if shard is active uint8 status = signer.status; @@ -212,11 +212,11 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { ); } - // Converts a `TssKey` into an `KeyInfo` unique identifier - function _tssKeyToShardId(TssKey memory tssKey) private pure returns (bytes32) { + // Converts a `TssKey` into an `ShardStore.ShardID` unique identifier + function _tssKeyToShardId(TssKey memory tssKey) private pure returns (ShardStore.ShardID) { // The tssKey coord x is already collision resistant // if we are unsure about it, we can hash the coord and parity bit - return bytes32(tssKey.xCoord); + return ShardStore.ShardID.wrap(bytes32(tssKey.xCoord)); } // Initialize networks @@ -236,13 +236,14 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { function _registerKeys(TssKey[] memory keys) private { // We don't perform any arithmetic operation, except iterate a loop unchecked { + ShardStore.MainStorage storage store = ShardStore.getMainStorage(); + // Register or activate tss key (revoked keys keep the previous nonce) for (uint256 i = 0; i < keys.length; i++) { TssKey memory newKey = keys[i]; // Read shard from storage - bytes32 shardId = _tssKeyToShardId(newKey); - KeyInfo storage shard = _shards[shardId]; + ShardStore.KeyInfo storage shard = store.get(_tssKeyToShardId(newKey)); uint8 status = shard.status; uint32 nonce = shard.nonce; @@ -276,44 +277,17 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { } } - // Revoke TSS keys - function _revokeKeys(TssKey[] memory keys) private { - // We don't perform any arithmetic operation, except iterate a loop - unchecked { - // Revoke tss keys - for (uint256 i = 0; i < keys.length; i++) { - TssKey memory revokedKey = keys[i]; - - // Read shard from storage - bytes32 shardId = _tssKeyToShardId(revokedKey); - KeyInfo storage shard = _shards[shardId]; - - // Check if the shard exists and is active - require(shard.nonce > 0, "shard doesn't exists, cannot revoke key"); - require((shard.status & SHARD_ACTIVE) > 0, "cannot revoke a shard key already revoked"); - - // Check y-parity - { - uint8 yParity = (shard.status & SHARD_Y_PARITY) > 0 ? 1 : 0; - require(yParity == revokedKey.yParity, "invalid y parity bit, cannot revoke key"); - } - - // Disable SHARD_ACTIVE bitflag - shard.status = shard.status & (~SHARD_ACTIVE); // Disable active flag - } - } - } - // Register/Revoke TSS keys and emits [`KeySetChanged`] event function _updateKeys(bytes32 messageHash, TssKey[] memory keysToRevoke, TssKey[] memory newKeys) private { - // We don't perform any arithmetic operation, except iterate a loop - unchecked { - // Revoke tss keys (revoked keys can be registred again keeping the previous nonce) - _revokeKeys(keysToRevoke); + ShardStore.MainStorage storage shards = ShardStore.getMainStorage(); - // Register or activate revoked keys - _registerKeys(newKeys); - } + // Revoke tss keys (revoked keys can be registred again keeping the previous nonce) + shards.revokeKeys(keysToRevoke); + + // Register or activate revoked keys + shards.registerTssKeys(newKeys); + + // Emit event emit KeySetChanged(messageHash, keysToRevoke, newKeys); } @@ -670,17 +644,18 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { } // OBS: remove != revoke (when revoked, you cannot register again) - function sudoRemoveShards(TssKey[] memory shards) external payable { + function sudoRemoveShards(TssKey[] memory revokedKeys) external payable { require(msg.sender == _getAdmin(), "unauthorized"); - for (uint256 i; i < shards.length; i++) { - bytes32 shardId = _tssKeyToShardId(shards[i]); - delete _shards[shardId]; - } + ShardStore.MainStorage storage shards = ShardStore.getMainStorage(); + shards.revokeKeys(revokedKeys); + emit KeySetChanged(bytes32(0), revokedKeys, new TssKey[](0)); } - function sudoAddShards(TssKey[] memory shards) external payable { + function sudoAddShards(TssKey[] memory newKeys) external payable { require(msg.sender == _getAdmin(), "unauthorized"); - _registerKeys(shards); + ShardStore.MainStorage storage shards = ShardStore.getMainStorage(); + shards.registerTssKeys(newKeys); + emit KeySetChanged(bytes32(0), new TssKey[](0), newKeys); } // DANGER: This function is for migration purposes only, it allows the admin to set any storage slot. diff --git a/src/storage/Shards.sol b/src/storage/Shards.sol index 7878261..b87c90f 100644 --- a/src/storage/Shards.sol +++ b/src/storage/Shards.sol @@ -4,6 +4,7 @@ pragma solidity ^0.8.20; import {TssKey} from "../Primitives.sol"; import {EnumerableSet, Pointer} from "../utils/EnumerableSet.sol"; +import {BranchlessMath} from "../utils/BranchlessMath.sol"; import {StoragePtr} from "../utils/Pointer.sol"; /** @@ -20,6 +21,9 @@ library ShardStore { */ bytes32 internal constant _EIP7201_NAMESPACE = 0x582bcdebbeef4fb96dde802cfe96e9942657f4bedb5cfe94e8786bb683eb1f00; + uint8 internal constant SHARD_ACTIVE = (1 << 0); // Shard active bitflag + uint8 internal constant SHARD_Y_PARITY = (1 << 1); // Pubkey y parity bitflag + /** * @dev Shard ID, this is the xCoord of the TssKey */ @@ -44,7 +48,7 @@ library ShardStore { */ struct KeyInfo { uint216 _gap; - ShardStatus status; + uint8 status; uint32 nonce; } @@ -62,6 +66,7 @@ library ShardStore { error ShardAlreadyRegistered(ShardID id); error ShardNotExists(ShardID id); + error IndexOutOfBounds(uint256 index); function getMainStorage() internal pure returns (MainStorage storage $) { assembly { @@ -101,15 +106,15 @@ library ShardStore { * Returns true if the value was added to the set, that is if it was not * already present. */ - function add(MainStorage storage store, TssKey memory shard) internal returns (bool) { - StoragePtr ptr = store.shards.add(bytes32(shard.xCoord)); + function set(MainStorage storage store, ShardID xCoord, KeyInfo memory shard) internal returns (bool) { + StoragePtr ptr = store.shards.add(ShardID.unwrap(xCoord)); if (ptr.isNull()) { return false; } KeyInfo storage keyInfo = _getKeyInfo(ptr); - keyInfo._gap = 0; - keyInfo.status = ShardStatus.Active; - keyInfo.nonce = 1; + keyInfo._gap = shard._gap; + keyInfo.status = shard.status; + keyInfo.nonce = shard.nonce; return true; } @@ -126,14 +131,14 @@ library ShardStore { } KeyInfo storage keyInfo = _getKeyInfo(ptr); keyInfo._gap = 0; - keyInfo.status = ShardStatus.Revoked; + keyInfo.status &= ~SHARD_ACTIVE; return true; } /** * @dev Returns the number of values on the set. O(1). */ - function _length(MainStorage storage store) private view returns (uint256) { + function length(MainStorage storage store) internal view returns (uint256) { return store.shards.length(); } @@ -149,7 +154,9 @@ library ShardStore { */ function at(MainStorage storage store, uint256 index) internal view returns (KeyInfo storage) { StoragePtr ptr = store.shards.at(index); - require(ptr.isNull() == false, "ShardStore: index out of bounds"); + if (ptr.isNull()) { + revert IndexOutOfBounds(index); + } return _getKeyInfo(ptr); } @@ -160,12 +167,76 @@ library ShardStore { * * - `key` must be in the map. */ - function get(MainStorage storage store, bytes32 key) internal view returns (KeyInfo storage) { - StoragePtr ptr = store.shards.get(key); - require(ptr.isNull() == false, "ShardStore: key not found"); + function get(MainStorage storage store, ShardID key) internal view returns (KeyInfo storage) { + StoragePtr ptr = store.shards.get(ShardID.unwrap(key)); + if (ptr.isNull()) { + revert ShardNotExists(key); + } return _getKeyInfo(ptr); } + /** + * @dev Returns the value associated with `key`. O(1). + */ + function tryGet(MainStorage storage store, ShardID key) internal view returns (bool, KeyInfo storage) { + StoragePtr ptr = store.shards.get(ShardID.unwrap(key)); + return (ptr.isNull(), _getKeyInfo(ptr)); + } + + function registerTssKeys(ShardStore.MainStorage storage store, TssKey[] memory keys) internal { + // We don't perform any arithmetic operation, except iterate a loop + unchecked { + // Register or activate tss key (revoked keys keep the previous nonce) + for (uint256 i = 0; i < keys.length; i++) { + TssKey memory newKey = keys[i]; + require(newKey.yParity == (newKey.yParity & 1), "y parity bit must be 0 or 1, cannot register shard"); + + ShardID id = ShardID.wrap(bytes32(newKey.xCoord)); + KeyInfo storage shard = _getKeyInfo(store.shards.getUnchecked(ShardID.unwrap(id))); + + // Check if the shard is already registered + if (store.shards.add(ShardID.unwrap(id)).isNull()) { + revert ShardAlreadyRegistered(id); + } + + shard.status = BranchlessMath.ternaryU8(newKey.yParity > 0, 0, SHARD_Y_PARITY) | SHARD_ACTIVE; + shard.nonce += uint32(BranchlessMath.toUint(shard.nonce == 0)); + } + } + } + + // Revoke TSS keys + function revokeKeys(ShardStore.MainStorage storage store, TssKey[] memory keys) internal { + // We don't perform any arithmetic operation, except iterate a loop + unchecked { + // Revoke tss keys + for (uint256 i = 0; i < keys.length; i++) { + TssKey memory revokedKey = keys[i]; + + // Read shard from storage + ShardID id = ShardID.wrap(bytes32(revokedKey.xCoord)); + KeyInfo storage shard; + { + bool shardExists; + (shardExists, shard) = tryGet(store, id); + + if (!shardExists || shard.nonce == 0) { + revert ShardNotExists(id); + } + } + + // Check y-parity + { + uint8 yParity = (shard.status & SHARD_Y_PARITY) > 0 ? 1 : 0; + require(yParity == revokedKey.yParity, "y parity bit mismatch, cannot revoke key"); + } + + // Disable SHARD_ACTIVE bitflag + shard.status = shard.status & (~SHARD_ACTIVE); // Disable active flag + } + } + } + // /** // * @dev Return the entire set in an array // * diff --git a/src/utils/EnumerableSet.sol b/src/utils/EnumerableSet.sol index a1fcd11..108b962 100644 --- a/src/utils/EnumerableSet.sol +++ b/src/utils/EnumerableSet.sol @@ -175,6 +175,14 @@ library EnumerableSet { } } + function getUnchecked(Map storage map, bytes32 key) internal pure returns (StoragePtr r) { + assembly ("memory-safe") { + mstore(0x00, key) + mstore(0x20, add(map.slot, 1)) + r := keccak256(0x00, 0x40) + } + } + // /** // * @dev Return the entire set in an array // * diff --git a/test/EnumerableSet.t.sol b/test/EnumerableSet.t.sol index 0a53055..00ac341 100644 --- a/test/EnumerableSet.t.sol +++ b/test/EnumerableSet.t.sol @@ -168,7 +168,7 @@ contract EnumerableSetTest is Test { /** * Test if `Map.add` and `Map.at` work as expected. */ - function test_fuzzz() external { + function test_fuzz() external { // bytes32 key, uint256 value bytes32 key = bytes32(0); uint256 value = 256;