Skip to content

Commit

Permalink
Move shard logic to shard storage
Browse files Browse the repository at this point in the history
  • Loading branch information
Lohann committed Nov 4, 2024
1 parent 824ccdc commit b80ab5f
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 66 deletions.
81 changes: 28 additions & 53 deletions src/Gateway.sol
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,6 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
// GMP message status
mapping(bytes32 => GmpInfo) private _messages;

// GAP necessary for migration purposes
mapping(GmpSender => mapping(uint16 => uint256)) private _deprecated_Deposits;
mapping(uint16 => bytes32) private _deprecated_Networks;

// Hash of the previous GMP message submitted.
bytes32 public prevMessageHash;

Expand Down Expand Up @@ -179,7 +175,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {

function keyInfo(bytes32 id) external view returns (ShardStore.KeyInfo memory) {
ShardStore.MainStorage storage store = ShardStore.getMainStorage();
return store.get(id);
return store.get(ShardStore.ShardID.wrap(id));
}

function networkId() external view returns (uint16) {
Expand All @@ -195,7 +191,11 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
*/
function _verifySignature(Signature calldata signature, bytes32 message) private view {
// Load shard from storage
KeyInfo storage signer = _shards[bytes32(signature.xCoord)];
ShardStore.KeyInfo storage signer;
{
ShardStore.MainStorage storage store = ShardStore.getMainStorage();
signer = store.get(ShardStore.ShardID.wrap(bytes32(signature.xCoord)));
}

// Verify if shard is active
uint8 status = signer.status;
Expand All @@ -212,11 +212,11 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
);
}

// Converts a `TssKey` into an `KeyInfo` unique identifier
function _tssKeyToShardId(TssKey memory tssKey) private pure returns (bytes32) {
// Converts a `TssKey` into an `ShardStore.ShardID` unique identifier
function _tssKeyToShardId(TssKey memory tssKey) private pure returns (ShardStore.ShardID) {
// The tssKey coord x is already collision resistant
// if we are unsure about it, we can hash the coord and parity bit
return bytes32(tssKey.xCoord);
return ShardStore.ShardID.wrap(bytes32(tssKey.xCoord));
}

// Initialize networks
Expand All @@ -236,13 +236,14 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
function _registerKeys(TssKey[] memory keys) private {
// We don't perform any arithmetic operation, except iterate a loop
unchecked {
ShardStore.MainStorage storage store = ShardStore.getMainStorage();

// Register or activate tss key (revoked keys keep the previous nonce)
for (uint256 i = 0; i < keys.length; i++) {
TssKey memory newKey = keys[i];

// Read shard from storage
bytes32 shardId = _tssKeyToShardId(newKey);
KeyInfo storage shard = _shards[shardId];
ShardStore.KeyInfo storage shard = store.get(_tssKeyToShardId(newKey));
uint8 status = shard.status;
uint32 nonce = shard.nonce;

Expand Down Expand Up @@ -276,44 +277,17 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
}
}

// Revoke TSS keys
function _revokeKeys(TssKey[] memory keys) private {
// We don't perform any arithmetic operation, except iterate a loop
unchecked {
// Revoke tss keys
for (uint256 i = 0; i < keys.length; i++) {
TssKey memory revokedKey = keys[i];

// Read shard from storage
bytes32 shardId = _tssKeyToShardId(revokedKey);
KeyInfo storage shard = _shards[shardId];

// Check if the shard exists and is active
require(shard.nonce > 0, "shard doesn't exists, cannot revoke key");
require((shard.status & SHARD_ACTIVE) > 0, "cannot revoke a shard key already revoked");

// Check y-parity
{
uint8 yParity = (shard.status & SHARD_Y_PARITY) > 0 ? 1 : 0;
require(yParity == revokedKey.yParity, "invalid y parity bit, cannot revoke key");
}

// Disable SHARD_ACTIVE bitflag
shard.status = shard.status & (~SHARD_ACTIVE); // Disable active flag
}
}
}

// Register/Revoke TSS keys and emits [`KeySetChanged`] event
function _updateKeys(bytes32 messageHash, TssKey[] memory keysToRevoke, TssKey[] memory newKeys) private {
// We don't perform any arithmetic operation, except iterate a loop
unchecked {
// Revoke tss keys (revoked keys can be registred again keeping the previous nonce)
_revokeKeys(keysToRevoke);
ShardStore.MainStorage storage shards = ShardStore.getMainStorage();

// Register or activate revoked keys
_registerKeys(newKeys);
}
// Revoke tss keys (revoked keys can be registred again keeping the previous nonce)
shards.revokeKeys(keysToRevoke);

// Register or activate revoked keys
shards.registerTssKeys(newKeys);

// Emit event
emit KeySetChanged(messageHash, keysToRevoke, newKeys);
}

Expand Down Expand Up @@ -670,17 +644,18 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
}

// OBS: remove != revoke (when revoked, you cannot register again)
function sudoRemoveShards(TssKey[] memory shards) external payable {
function sudoRemoveShards(TssKey[] memory revokedKeys) external payable {
require(msg.sender == _getAdmin(), "unauthorized");
for (uint256 i; i < shards.length; i++) {
bytes32 shardId = _tssKeyToShardId(shards[i]);
delete _shards[shardId];
}
ShardStore.MainStorage storage shards = ShardStore.getMainStorage();
shards.revokeKeys(revokedKeys);
emit KeySetChanged(bytes32(0), revokedKeys, new TssKey[](0));
}

function sudoAddShards(TssKey[] memory shards) external payable {
function sudoAddShards(TssKey[] memory newKeys) external payable {
require(msg.sender == _getAdmin(), "unauthorized");
_registerKeys(shards);
ShardStore.MainStorage storage shards = ShardStore.getMainStorage();
shards.registerTssKeys(newKeys);
emit KeySetChanged(bytes32(0), new TssKey[](0), newKeys);
}

// DANGER: This function is for migration purposes only, it allows the admin to set any storage slot.
Expand Down
95 changes: 83 additions & 12 deletions src/storage/Shards.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pragma solidity ^0.8.20;

import {TssKey} from "../Primitives.sol";
import {EnumerableSet, Pointer} from "../utils/EnumerableSet.sol";
import {BranchlessMath} from "../utils/BranchlessMath.sol";
import {StoragePtr} from "../utils/Pointer.sol";

/**
Expand All @@ -20,6 +21,9 @@ library ShardStore {
*/
bytes32 internal constant _EIP7201_NAMESPACE = 0x582bcdebbeef4fb96dde802cfe96e9942657f4bedb5cfe94e8786bb683eb1f00;

uint8 internal constant SHARD_ACTIVE = (1 << 0); // Shard active bitflag
uint8 internal constant SHARD_Y_PARITY = (1 << 1); // Pubkey y parity bitflag

/**
* @dev Shard ID, this is the xCoord of the TssKey
*/
Expand All @@ -44,7 +48,7 @@ library ShardStore {
*/
struct KeyInfo {
uint216 _gap;
ShardStatus status;
uint8 status;
uint32 nonce;
}

Expand All @@ -62,6 +66,7 @@ library ShardStore {

error ShardAlreadyRegistered(ShardID id);
error ShardNotExists(ShardID id);
error IndexOutOfBounds(uint256 index);

function getMainStorage() internal pure returns (MainStorage storage $) {
assembly {
Expand Down Expand Up @@ -101,15 +106,15 @@ library ShardStore {
* Returns true if the value was added to the set, that is if it was not
* already present.
*/
function add(MainStorage storage store, TssKey memory shard) internal returns (bool) {
StoragePtr ptr = store.shards.add(bytes32(shard.xCoord));
function set(MainStorage storage store, ShardID xCoord, KeyInfo memory shard) internal returns (bool) {
StoragePtr ptr = store.shards.add(ShardID.unwrap(xCoord));
if (ptr.isNull()) {
return false;
}
KeyInfo storage keyInfo = _getKeyInfo(ptr);
keyInfo._gap = 0;
keyInfo.status = ShardStatus.Active;
keyInfo.nonce = 1;
keyInfo._gap = shard._gap;
keyInfo.status = shard.status;
keyInfo.nonce = shard.nonce;
return true;
}

Expand All @@ -126,14 +131,14 @@ library ShardStore {
}
KeyInfo storage keyInfo = _getKeyInfo(ptr);
keyInfo._gap = 0;
keyInfo.status = ShardStatus.Revoked;
keyInfo.status &= ~SHARD_ACTIVE;
return true;
}

/**
* @dev Returns the number of values on the set. O(1).
*/
function _length(MainStorage storage store) private view returns (uint256) {
function length(MainStorage storage store) internal view returns (uint256) {
return store.shards.length();
}

Expand All @@ -149,7 +154,9 @@ library ShardStore {
*/
function at(MainStorage storage store, uint256 index) internal view returns (KeyInfo storage) {
StoragePtr ptr = store.shards.at(index);
require(ptr.isNull() == false, "ShardStore: index out of bounds");
if (ptr.isNull()) {
revert IndexOutOfBounds(index);
}
return _getKeyInfo(ptr);
}

Expand All @@ -160,12 +167,76 @@ library ShardStore {
*
* - `key` must be in the map.
*/
function get(MainStorage storage store, bytes32 key) internal view returns (KeyInfo storage) {
StoragePtr ptr = store.shards.get(key);
require(ptr.isNull() == false, "ShardStore: key not found");
function get(MainStorage storage store, ShardID key) internal view returns (KeyInfo storage) {
StoragePtr ptr = store.shards.get(ShardID.unwrap(key));
if (ptr.isNull()) {
revert ShardNotExists(key);
}
return _getKeyInfo(ptr);
}

/**
* @dev Returns the value associated with `key`. O(1).
*/
function tryGet(MainStorage storage store, ShardID key) internal view returns (bool, KeyInfo storage) {
StoragePtr ptr = store.shards.get(ShardID.unwrap(key));
return (ptr.isNull(), _getKeyInfo(ptr));
}

function registerTssKeys(ShardStore.MainStorage storage store, TssKey[] memory keys) internal {
// We don't perform any arithmetic operation, except iterate a loop
unchecked {
// Register or activate tss key (revoked keys keep the previous nonce)
for (uint256 i = 0; i < keys.length; i++) {
TssKey memory newKey = keys[i];
require(newKey.yParity == (newKey.yParity & 1), "y parity bit must be 0 or 1, cannot register shard");

ShardID id = ShardID.wrap(bytes32(newKey.xCoord));
KeyInfo storage shard = _getKeyInfo(store.shards.getUnchecked(ShardID.unwrap(id)));

// Check if the shard is already registered
if (store.shards.add(ShardID.unwrap(id)).isNull()) {
revert ShardAlreadyRegistered(id);
}

shard.status = BranchlessMath.ternaryU8(newKey.yParity > 0, 0, SHARD_Y_PARITY) | SHARD_ACTIVE;
shard.nonce += uint32(BranchlessMath.toUint(shard.nonce == 0));
}
}
}

// Revoke TSS keys
function revokeKeys(ShardStore.MainStorage storage store, TssKey[] memory keys) internal {
// We don't perform any arithmetic operation, except iterate a loop
unchecked {
// Revoke tss keys
for (uint256 i = 0; i < keys.length; i++) {
TssKey memory revokedKey = keys[i];

// Read shard from storage
ShardID id = ShardID.wrap(bytes32(revokedKey.xCoord));
KeyInfo storage shard;
{
bool shardExists;
(shardExists, shard) = tryGet(store, id);

if (!shardExists || shard.nonce == 0) {
revert ShardNotExists(id);
}
}

// Check y-parity
{
uint8 yParity = (shard.status & SHARD_Y_PARITY) > 0 ? 1 : 0;
require(yParity == revokedKey.yParity, "y parity bit mismatch, cannot revoke key");
}

// Disable SHARD_ACTIVE bitflag
shard.status = shard.status & (~SHARD_ACTIVE); // Disable active flag
}
}
}

// /**
// * @dev Return the entire set in an array
// *
Expand Down
8 changes: 8 additions & 0 deletions src/utils/EnumerableSet.sol
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,14 @@ library EnumerableSet {
}
}

function getUnchecked(Map storage map, bytes32 key) internal pure returns (StoragePtr r) {
assembly ("memory-safe") {
mstore(0x00, key)
mstore(0x20, add(map.slot, 1))
r := keccak256(0x00, 0x40)
}
}

// /**
// * @dev Return the entire set in an array
// *
Expand Down
2 changes: 1 addition & 1 deletion test/EnumerableSet.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ contract EnumerableSetTest is Test {
/**
* Test if `Map.add` and `Map.at` work as expected.
*/
function test_fuzzz() external {
function test_fuzz() external {
// bytes32 key, uint256 value
bytes32 key = bytes32(0);
uint256 value = 256;
Expand Down

0 comments on commit b80ab5f

Please sign in to comment.