diff --git a/src/Gateway.sol b/src/Gateway.sol index a10acde..b7e38a3 100644 --- a/src/Gateway.sol +++ b/src/Gateway.sol @@ -72,20 +72,24 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { bytes32 internal constant FIRST_MESSAGE_PLACEHOLDER = bytes32(uint256(2 ** 256 - 1)); // Shard data, maps the pubkey coordX (which is already collision resistant) to shard info. - mapping(bytes32 => KeyInfo) _shards; + mapping(bytes32 => KeyInfo) private _shards; // GMP message status - mapping(bytes32 => GmpInfo) _messages; + mapping(bytes32 => GmpInfo) private _messages; + + // GAP necessary for migration purposes + mapping(GmpSender => mapping(uint16 => uint256)) private _deprecated_Deposits; + mapping(uint16 => bytes32) private _deprecated_Networks; // Hash of the previous GMP message submitted. bytes32 public prevMessageHash; // Replay protection mechanism, stores the hash of the executed messages // messageHash => shardId - mapping(bytes32 => bytes32) _executedMessages; + mapping(bytes32 => bytes32) private _executedMessages; // Network ID => Source network - mapping(uint16 => NetworkInfo) _networkInfo; + mapping(uint16 => NetworkInfo) private _networkInfo; /** * @dev Shard info stored in the Gateway Contract @@ -177,6 +181,10 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { return NETWORK_ID; } + function networkInfo(uint16 id) external view returns (NetworkInfo memory) { + return _networkInfo[id]; + } + /** * @dev Verify if shard exists, if the TSS signature is valid then increment shard's nonce. */ @@ -190,7 +198,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { // Load y parity bit, it must be 27 (even), or 28 (odd) // ref: https://ethereum.github.io/yellowpaper/paper.pdf - uint8 yParity = uint8(BranchlessMath.select((status & SHARD_Y_PARITY) > 0, 28, 27)); + uint8 yParity = BranchlessMath.ternaryU8((status & SHARD_Y_PARITY) > 0, 28, 27); // Verify Signature require( @@ -206,13 +214,14 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { return bytes32(tssKey.xCoord); } - // Converts a `TssKey` into an `KeyInfo` unique identifier + // Initialize networks function _updateNetworks(Network[] calldata networks) private { for (uint256 i = 0; i < networks.length; i++) { Network calldata network = networks[i]; - bytes32 domainSeparator = computeDomainSeparator(network.id, network.gateway); NetworkInfo storage info = _networkInfo[network.id]; - info.domainSeparator = domainSeparator; + require(info.domainSeparator == bytes32(0), "network already initialized"); + require(network.id != NETWORK_ID || network.gateway == address(this), "wrong gateway address"); + info.domainSeparator = computeDomainSeparator(network.id, network.gateway); info.gasLimit = 15_000_000; // Default to 15M gas info.relativeGasPrice = UFloatMath.ONE; info.baseFee = 0; @@ -248,10 +257,10 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { ); // if is a new shard shard, set its initial nonce to 1 - shard.nonce = uint32(BranchlessMath.select(nonce == 0, 1, nonce)); + shard.nonce = uint32(BranchlessMath.ternaryU32(nonce == 0, 1, nonce)); // enable/disable the y-parity flag - status = uint8(BranchlessMath.select(yParity > 0, status | SHARD_Y_PARITY, status & ~SHARD_Y_PARITY)); + status = BranchlessMath.ternaryU8(yParity > 0, status | SHARD_Y_PARITY, status & ~SHARD_Y_PARITY); // enable SHARD_ACTIVE bitflag status |= SHARD_ACTIVE; @@ -340,7 +349,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { // https://eips.ethereum.org/EIPS/eip-150 uint256 gasNeeded = gasLimit.saturatingMul(64).saturatingDiv(63); // to guarantee it was provided enough gas to execute the GMP message - gasNeeded = gasNeeded.saturatingAdd(6412); + gasNeeded = gasNeeded.saturatingAdd(10000); require(gasleft() >= gasNeeded, "insufficient gas to execute GMP message"); } @@ -374,7 +383,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { } // Update GMP status - status = GmpStatus(BranchlessMath.select(success, uint256(GmpStatus.SUCCESS), uint256(GmpStatus.REVERT))); + status = GmpStatus(BranchlessMath.ternary(success, uint256(GmpStatus.SUCCESS), uint256(GmpStatus.REVERT))); // Persist gmp execution status on storage gmp.status = status; @@ -413,11 +422,9 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { unchecked { // Compute GMP gas used uint256 gasUsed = 7214; - { - gasUsed = gasUsed.saturatingAdd(GasUtils.calldataGasCost()); - gasUsed = gasUsed.saturatingAdd(GasUtils.proxyOverheadGasCost(uint16(msg.data.length), 64)); - gasUsed = gasUsed.saturatingAdd(initialGas - gasleft()); - } + gasUsed = gasUsed.saturatingAdd(GasUtils.calldataBaseCost()); + gasUsed = gasUsed.saturatingAdd(GasUtils.proxyOverheadGasCost(uint16(msg.data.length), 64)); + gasUsed = gasUsed.saturatingAdd(initialGas - gasleft()); // Compute refund amount uint256 refund = BranchlessMath.min(gasUsed.saturatingMul(tx.gasprice), address(this).balance); @@ -429,34 +436,48 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { } } - function _setNetworkInfo(bytes32 executor, bytes32 messageHash, UpdateNetworkInfo calldata data) private { - require(data.mortality >= block.number, "message expired"); + function _setNetworkInfo(bytes32 executor, bytes32 messageHash, UpdateNetworkInfo calldata info) private { + require(info.mortality >= block.number, "message expired"); require(executor != bytes32(0), "executor cannot be zero"); // Verify signature and if the message was already executed require(_executedMessages[messageHash] == bytes32(0), "message already executed"); - // Update network info and store the message hash to prevent replay attacks - NetworkInfo storage networkInfo = _networkInfo[data.networkId]; - - // Verify if the domain separator is not zero - require((networkInfo.domainSeparator | data.domainSeparator) != bytes32(0), "domain separator cannot be zero"); + // Update network info + NetworkInfo memory stored = _networkInfo[info.networkId]; - // Update domain separator if it's not zero - if (data.domainSeparator != bytes32(0)) { - networkInfo.domainSeparator = messageHash; - } + // Verify and update domain separator if it's not zero + stored.domainSeparator = + BranchlessMath.ternary(info.domainSeparator != bytes32(0), info.domainSeparator, stored.domainSeparator); + require(stored.domainSeparator != bytes32(0), "domain separator cannot be zero"); // Update gas limit if it's not zero - networkInfo.gasLimit = uint64(BranchlessMath.select(data.gasLimit > 0, data.gasLimit, networkInfo.gasLimit)); - if (UFloat9x56.unwrap(data.relativeGasPrice) > 0 || data.baseFee > 0) { - networkInfo.relativeGasPrice = networkInfo.relativeGasPrice; - networkInfo.baseFee = networkInfo.baseFee; + stored.gasLimit = BranchlessMath.ternaryU64(info.gasLimit > 0, info.gasLimit, stored.gasLimit); + + // Update relative gas price and base fee if any of them are greater than zero + { + bool shouldUpdate = UFloat9x56.unwrap(info.relativeGasPrice) > 0 || info.baseFee > 0; + stored.relativeGasPrice = UFloat9x56.wrap( + BranchlessMath.ternaryU64( + shouldUpdate, UFloat9x56.unwrap(info.relativeGasPrice), UFloat9x56.unwrap(stored.relativeGasPrice) + ) + ); + stored.baseFee = BranchlessMath.ternaryU128(shouldUpdate, info.baseFee, stored.baseFee); } + + // Save the message hash to prevent replay attacks _executedMessages[messageHash] = executor; + // Update network info + _networkInfo[info.networkId] = stored; + emit NetworkUpdated( - messageHash, data.networkId, data.domainSeparator, data.relativeGasPrice, data.baseFee, data.gasLimit + messageHash, + info.networkId, + stored.domainSeparator, + stored.relativeGasPrice, + stored.baseFee, + stored.gasLimit ); } @@ -494,10 +515,10 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { address destinationAddress, uint16 destinationNetwork, uint256 executionGasLimit, - bytes memory data + bytes calldata data ) external payable returns (bytes32) { // Check if the message data is too large - require(data.length < MAX_PAYLOAD_SIZE, "msg data too large"); + require(data.length <= MAX_PAYLOAD_SIZE, "msg data too large"); // Check if the destination network is supported NetworkInfo storage info = _networkInfo[destinationNetwork]; @@ -506,7 +527,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { // Check if the sender has deposited enougth funds to execute the GMP message { - uint256 nonZeros = GasUtils.countNonZeros(data); + uint256 nonZeros = GasUtils.countNonZerosCalldata(data); uint256 zeros = data.length - nonZeros; uint256 msgPrice = GasUtils.estimateWeiCost( info.relativeGasPrice, info.baseFee, uint16(nonZeros), uint16(zeros), executionGasLimit @@ -521,18 +542,37 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { bytes32 prevHash = prevMessageHash; // if the messageHash is the first message, we use a zero salt - uint256 salt = BranchlessMath.select(prevHash == FIRST_MESSAGE_PLACEHOLDER, 0, uint256(prevHash)); + uint256 salt = BranchlessMath.ternary(prevHash == FIRST_MESSAGE_PLACEHOLDER, 0, uint256(prevHash)); // Create GMP message and update prevMessageHash - GmpMessage memory message = - GmpMessage(source, NETWORK_ID, destinationAddress, destinationNetwork, executionGasLimit, salt, data); - prevHash = message.eip712TypedHash(domainSeparator); - prevMessageHash = prevHash; + bytes memory payload; + { + GmpMessage memory message = + GmpMessage(source, NETWORK_ID, destinationAddress, destinationNetwork, executionGasLimit, salt, data); + prevHash = message.eip712TypedHash(domainSeparator); + prevMessageHash = prevHash; + payload = message.data; + } - emit GmpCreated( - prevHash, GmpSender.unwrap(source), destinationAddress, destinationNetwork, executionGasLimit, salt, data - ); - return prevHash; + // Emit `GmpCreated` event without copy the data, to simplify the gas estimation. + // the assembly code below is equivalent to: + // ```solidity + // emit GmpCreated(prevHash, source, destinationAddress, destinationNetwork, executionGasLimit, salt, data); + // return prevHash; + // ``` + bytes32 eventSelector = GmpCreated.selector; + assembly { + let ptr := sub(payload, 0x80) + mstore(ptr, destinationNetwork) // dest network + mstore(add(ptr, 0x20), executionGasLimit) // gas limit + mstore(add(ptr, 0x40), salt) // salt + mstore(add(ptr, 0x60), 0x80) // data offset + let size := and(add(mload(payload), 31), 0xffffffe0) + size := add(size, 160) + log4(ptr, size, eventSelector, prevHash, source, destinationAddress) + mstore(0, prevHash) + return(0, 32) + } } /** @@ -555,7 +595,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { require(baseFee > 0 || UFloat9x56.unwrap(relativeGasPrice) > 0, "unsupported network"); // if the message data is too large, we use the maximum base fee. - baseFee = BranchlessMath.select(messageSize > MAX_PAYLOAD_SIZE, 2 ** 256 - 1, baseFee); + baseFee = BranchlessMath.ternary(messageSize > MAX_PAYLOAD_SIZE, 2 ** 256 - 1, baseFee); // Estimate the cost return GasUtils.estimateWeiCost(relativeGasPrice, baseFee, uint16(messageSize), 0, gasLimit); @@ -573,7 +613,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 { function _getAdmin() private view returns (address admin) { admin = ERC1967.getAdmin(); // If the admin slot is empty, then the 0xd4833be6144AF48d4B09E5Ce41f826eEcb7706D6 is the admin - admin = BranchlessMath.select(admin == address(0x0), 0xd4833be6144AF48d4B09E5Ce41f826eEcb7706D6, admin); + admin = BranchlessMath.ternary(admin == address(0x0), 0xd4833be6144AF48d4B09E5Ce41f826eEcb7706D6, admin); } function setAdmin(address newAdmin) external payable { diff --git a/src/utils/BranchlessMath.sol b/src/utils/BranchlessMath.sol index de9c365..368f629 100644 --- a/src/utils/BranchlessMath.sol +++ b/src/utils/BranchlessMath.sol @@ -11,20 +11,20 @@ library BranchlessMath { * @dev Returns the smallest of two numbers. */ function min(uint256 x, uint256 y) internal pure returns (uint256) { - return select(x < y, x, y); + return ternary(x < y, x, y); } /** * @dev Returns the largest of two numbers. */ function max(uint256 x, uint256 y) internal pure returns (uint256) { - return select(x > y, x, y); + return ternary(x > y, x, y); } /** * @dev If `condition` is true returns `a`, otherwise returns `b`. */ - function select(bool condition, uint256 a, uint256 b) internal pure returns (uint256) { + function ternary(bool condition, uint256 a, uint256 b) internal pure returns (uint256) { unchecked { // branchless select, works because: // b ^ (a ^ b) == a @@ -40,13 +40,67 @@ library BranchlessMath { /** * @dev If `condition` is true returns `a`, otherwise returns `b`. + * see `BranchlessMath.ternary` */ - function select(bool condition, address a, address b) internal pure returns (address) { - return address(uint160(select(condition, uint256(uint160(a)), uint256(uint160(b))))); + function ternary(bool condition, address a, address b) internal pure returns (address r) { + assembly { + r := xor(b, mul(xor(a, b), condition)) + } + } + + /** + * @dev If `condition` is true returns `a`, otherwise returns `b`. + * see `BranchlessMath.ternary` + */ + function ternary(bool condition, bytes32 a, bytes32 b) internal pure returns (bytes32 r) { + assembly { + r := xor(b, mul(xor(a, b), condition)) + } + } + + /** + * @dev If `condition` is true returns `a`, otherwise returns `b`. + * see `BranchlessMath.ternary` + */ + function ternaryU128(bool condition, uint128 a, uint128 b) internal pure returns (uint128 r) { + assembly { + r := xor(b, mul(xor(a, b), condition)) + } + } + + /** + * @dev If `condition` is true returns `a`, otherwise returns `b`. + * see `BranchlessMath.ternary` + */ + function ternaryU64(bool condition, uint64 a, uint64 b) internal pure returns (uint64 r) { + assembly { + r := xor(b, mul(xor(a, b), condition)) + } + } + + /** + * @dev If `condition` is true returns `a`, otherwise returns `b`. + * see `BranchlessMath.ternary` + */ + function ternaryU32(bool condition, uint32 a, uint32 b) internal pure returns (uint32 r) { + assembly { + r := xor(b, mul(xor(a, b), condition)) + } + } + + /** + * @dev If `condition` is true returns `a`, otherwise returns `b`. + * see `BranchlessMath.ternary` + */ + function ternaryU8(bool condition, uint8 a, uint8 b) internal pure returns (uint8 r) { + assembly { + r := xor(b, mul(xor(a, b), condition)) + } } /** * @dev If `condition` is true return `value`, otherwise return zero. + * see `BranchlessMath.ternary` */ function selectIf(bool condition, uint256 value) internal pure returns (uint256) { unchecked { diff --git a/src/utils/GasUtils.sol b/src/utils/GasUtils.sol index 41edbaf..8179f6f 100644 --- a/src/utils/GasUtils.sol +++ b/src/utils/GasUtils.sol @@ -7,33 +7,79 @@ import {UFloat9x56, UFloatMath} from "./Float9x56.sol"; import {BranchlessMath} from "./BranchlessMath.sol"; /** - * @dev Utilities for branchless operations, useful when a constant gas cost is required. + * @dev Utilities for compute the GMP gas price, gas cost and gas needed. */ library GasUtils { - uint256 internal constant EXECUTION_BASE_COST = 39361 + 6700; + /** + * @dev Base cost of the `IExecutor.execute` method. + */ + uint256 internal constant EXECUTION_BASE_COST = 37647 + 6800; + + /** + * @dev Base cost of the `IGateway.submitMessage` method. + */ + uint256 internal constant SUBMIT_BASE_COST = 9640 + 6800 + 6500; using BranchlessMath for uint256; /** * @dev Compute the amount of gas used by the `GatewayProxy`. - * @param calldataLen The length of the calldata - * @param returnLen The length of the return data + * @param calldataLen The length of the calldata in bytes + * @param returnLen The length of the return data in bytes */ - function proxyOverheadGasCost(uint16 calldataLen, uint16 returnLen) internal pure returns (uint256) { + function proxyOverheadGasCost(uint256 calldataLen, uint256 returnLen) internal pure returns (uint256) { unchecked { + // Convert the calldata and return data length to words + calldataLen = calldataLen.saturatingAdd(31) >> 5; + returnLen = returnLen.saturatingAdd(31) >> 5; + // Base cost: OPCODES + COLD READ STORAGE _implementation uint256 gasCost = 2257 + 2500; // CALLDATACOPY - gasCost += ((uint256(calldataLen) + 31) >> 5) * 3; + gasCost = gasCost.saturatingAdd(calldataLen * 3); // RETURNDATACOPY - gasCost += ((uint256(returnLen) + 31) >> 5) * 3; + gasCost = gasCost.saturatingAdd(returnLen * 3); // MEMORY EXPANSION uint256 words = BranchlessMath.max(calldataLen, returnLen); - words = (words + 31) >> 5; + gasCost = gasCost.saturatingAdd((words.saturatingMul(words) >> 9).saturatingAdd(words * 3)); + return gasCost; + } + } + + /** + * @dev Compute the gas cost of the `IGateway.submitMessage` method. + * @param messageSize The size of the message in bytes. + */ + function submitMessageGasCost(uint16 messageSize) internal pure returns (uint256 gasCost) { + unchecked { + gasCost = SUBMIT_BASE_COST; + + // Convert message size to calldata size + uint256 calldataSize = ((messageSize + 31) & 0xffe0) + 164; + + // Proxy overhead + gasCost += proxyOverheadGasCost(uint16(calldataSize), 32); + + // `countNonZeros` gas cost + uint256 words = (messageSize + 31) >> 5; + gasCost += (words * 106) + (((words + 254) / 255) * 214); + + // CALLDATACOPY + gasCost += words * 3; + + // keccak256 (6 gas per word) + gasCost += words * 6; + + // emit GmpCreated() gas cost (8 gas per byte) + gasCost += words << 8; + + // Memory expansion cost + words += 13; gasCost += ((words * words) >> 9) + (words * 3); + return gasCost; } } @@ -99,7 +145,7 @@ library GasUtils { // Base cost uint256 words = (calldataSize + 31) >> 5; - gasNeeded += ((words - 1) / 15) * 1845; + gasNeeded += (words * 106) + (((words + 254) / 255) * 214); // CALLDATACOPY words = (messageSize + 31) >> 5; @@ -138,14 +184,16 @@ library GasUtils { /** * @dev Compute the gas that should be refunded to the executor for the execution. + * @param messageSize The size of the message. + * @param gasUsed The gas used by the gmp message. */ - function computeExecutionRefund(uint16 messageSize, uint256 gasLimit) + function computeExecutionRefund(uint16 messageSize, uint256 gasUsed) internal pure returns (uint256 executionCost) { // Add the base execution gas cost - executionCost = EXECUTION_BASE_COST.saturatingAdd(gasLimit); + executionCost = EXECUTION_BASE_COST.saturatingAdd(gasUsed); // Safety: The operations below can't overflow because the message size can't be greater than 2^16 unchecked { @@ -154,67 +202,31 @@ library GasUtils { // Proxy Overhead uint256 words = messagePadded + 388; // selector + Signature + GmpMessage - words = BranchlessMath.min(words, type(uint16).max); - executionCost += proxyOverheadGasCost(uint16(words), 64); + executionCost = executionCost.saturatingAdd(proxyOverheadGasCost(words, 64)); // Base Cost calculation words = (words + 31) >> 5; - executionCost += ((words - 1) / 15) * 1845; + executionCost = executionCost.saturatingAdd((words * 106) + (((words + 254) / 255) * 214)); // calldatacopy (3 gas per word) words = messagePadded >> 5; - executionCost += words * 3; + executionCost = executionCost.saturatingAdd(words * 3); // keccak256 (6 gas per word) - executionCost += words * 6; + executionCost = executionCost.saturatingAdd(words * 6); // Memory expansion cost words = 0xa4 + (words << 5); // onGmpReceived encoded call size words = (words + 31) & 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe0; words += 0x0200; // Memory size words = (words + 31) >> 5; // to words - executionCost += ((words * words) >> 9) + (words * 3); + executionCost = executionCost.saturatingAdd(((words * words) >> 9) + (words * 3)); } } - /** - * @dev Compute the transaction base cost. - * OBS: This function must be used ONLY inside Gateway.execute method, because it also consider itself gas cost. - */ - function executionGasCost(uint256 messageSize) internal pure returns (uint256 baseCost, uint256 executionCost) { - // Calculate Gateway.execute dynamic cost - executionCost = EXECUTION_BASE_COST; - unchecked { - uint256 words = (messageSize + 31) & 0xffe0; - words += 388; - executionCost += proxyOverheadGasCost(uint16(words), 64); - - // Base Cost calculation - words = (words + 31) >> 5; - executionCost += ((words - 1) / 15) * 1845; - - // calldatacopy (3 gas per word) - words = (messageSize + 31) >> 5; - executionCost += words * 3; - - // keccak256 (6 gas per word) - executionCost += words * 6; - - // Memory expansion cost - words = 0xa4 + (words << 5); // onGmpReceived encoded call size - words = (words + 31) & 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe0; - words += 0x0200; // Memory size - words = (words + 31) >> 5; // to words - executionCost += ((words * words) >> 9) + (words * 3); - } - - // Efficient algorithm for counting non-zero calldata bytes in chunks of 480 bytes at time - // computation gas cost = 1845 * ceil(msg.data.length / 480) + 61 - baseCost = calldataGasCost(); - } - /** * @dev Count the number of non-zero bytes in a byte sequence. + * gas cost = 217 + (words * 112) + ((words - 1) * 193) */ function countNonZeros(bytes memory data) internal pure returns (uint256 nonZeros) { /// @solidity memory-safe-assembly @@ -260,170 +272,56 @@ library GasUtils { } /** - * @dev Count non-zeros of a single 32 bytes word. + * @dev Count the number of non-zero bytes from calldata. + * gas cost = 224 + (words * 106) + (((words - 1) / 255) * 214) */ - function countNonZeros(bytes32 value) internal pure returns (uint256 nonZeros) { + function countNonZerosCalldata(bytes calldata data) internal pure returns (uint256 nonZeros) { /// @solidity memory-safe-assembly assembly { - // Normalize and count non-zero bytes in parallel - value := or(value, shr(4, value)) - value := or(value, shr(2, value)) - value := or(value, shr(1, value)) - value := and(value, 0x0101010101010101010101010101010101010101010101010101010101010101) - - // Sum bytes in parallel - value := add(value, shr(128, value)) - value := add(value, shr(64, value)) - value := add(value, shr(32, value)) - value := add(value, shr(16, value)) - value := add(value, shr(8, value)) - nonZeros := and(value, 0xff) - } - } - - /** - * @dev Compute the transaction base cost. - */ - function calldataGasCost() internal pure returns (uint256 baseCost) { - // Efficient algorithm for counting non-zero calldata bytes in chunks of 480 bytes at time - // computation gas cost = 1845 * ceil(msg.data.length / 480) + 61 - assembly { - baseCost := 0 + nonZeros := 0 for { - let ptr := 0 - let mask := 0x0101010101010101010101010101010101010101010101010101010101010101 - } lt(ptr, calldatasize()) { ptr := add(ptr, 32) } { - // 1 - let v := calldataload(ptr) - v := or(v, shr(4, v)) - v := or(v, shr(2, v)) - v := or(v, shr(1, v)) - v := and(v, mask) - { - // 2 - ptr := add(ptr, 32) + let ptr := data.offset + let end := add(ptr, data.length) + } lt(ptr, end) {} { + // calculate min(ptr + data.length, ptr + 8160) + let range := add(ptr, 8160) + range := xor(end, mul(xor(range, end), lt(range, end))) + + // Normalize and count non-zero bytes in parallel + let v := 0 + for {} lt(ptr, range) { ptr := add(ptr, 32) } { let r := calldataload(ptr) r := or(r, shr(4, r)) r := or(r, shr(2, r)) r := or(r, shr(1, r)) - r := and(r, mask) - v := add(v, r) - // 3 - ptr := add(ptr, 32) - r := calldataload(ptr) - r := or(r, shr(4, r)) - r := or(r, shr(2, r)) - r := or(r, shr(1, r)) - r := and(r, mask) - v := add(v, r) - // 4 - ptr := add(ptr, 32) - r := calldataload(ptr) - r := or(r, shr(4, r)) - r := or(r, shr(2, r)) - r := or(r, shr(1, r)) - r := and(r, mask) - v := add(v, r) - // 5 - ptr := add(ptr, 32) - r := calldataload(ptr) - r := or(r, shr(4, r)) - r := or(r, shr(2, r)) - r := or(r, shr(1, r)) - r := and(r, mask) - v := add(v, r) - // 6 - ptr := add(ptr, 32) - r := calldataload(ptr) - r := or(r, shr(4, r)) - r := or(r, shr(2, r)) - r := or(r, shr(1, r)) - r := and(r, mask) - v := add(v, r) - // 7 - ptr := add(ptr, 32) - r := calldataload(ptr) - r := or(r, shr(4, r)) - r := or(r, shr(2, r)) - r := or(r, shr(1, r)) - r := and(r, mask) - v := add(v, r) - // 8 - ptr := add(ptr, 32) - r := calldataload(ptr) - r := or(r, shr(4, r)) - r := or(r, shr(2, r)) - r := or(r, shr(1, r)) - r := and(r, mask) - v := add(v, r) - // 9 - ptr := add(ptr, 32) - r := calldataload(ptr) - r := or(r, shr(4, r)) - r := or(r, shr(2, r)) - r := or(r, shr(1, r)) - r := and(r, mask) - v := add(v, r) - // 10 - ptr := add(ptr, 32) - r := calldataload(ptr) - r := or(r, shr(4, r)) - r := or(r, shr(2, r)) - r := or(r, shr(1, r)) - r := and(r, mask) - v := add(v, r) - // 11 - ptr := add(ptr, 32) - r := calldataload(ptr) - r := or(r, shr(4, r)) - r := or(r, shr(2, r)) - r := or(r, shr(1, r)) - r := and(r, mask) - v := add(v, r) - // 12 - ptr := add(ptr, 32) - r := calldataload(ptr) - r := or(r, shr(4, r)) - r := or(r, shr(2, r)) - r := or(r, shr(1, r)) - r := and(r, mask) - v := add(v, r) - // 13 - ptr := add(ptr, 32) - r := calldataload(ptr) - r := or(r, shr(4, r)) - r := or(r, shr(2, r)) - r := or(r, shr(1, r)) - r := and(r, mask) - v := add(v, r) - // 14 - ptr := add(ptr, 32) - r := calldataload(ptr) - r := or(r, shr(4, r)) - r := or(r, shr(2, r)) - r := or(r, shr(1, r)) - r := and(r, mask) - v := add(v, r) - // 15 - ptr := add(ptr, 32) - r := calldataload(ptr) - r := or(r, shr(4, r)) - r := or(r, shr(2, r)) - r := or(r, shr(1, r)) - r := and(r, mask) + r := and(r, 0x0101010101010101010101010101010101010101010101010101010101010101) v := add(v, r) } - // Count bytes in parallel + // Sum bytes in parallel + { + let l := and(v, 0x00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff) + v := shr(8, xor(v, l)) + v := add(v, l) + } v := add(v, shr(128, v)) v := add(v, shr(64, v)) v := add(v, shr(32, v)) v := add(v, shr(16, v)) v := and(v, 0xffff) - v := add(and(v, 0xff), shr(8, v)) - baseCost := add(baseCost, v) + nonZeros := add(nonZeros, v) } - baseCost := add(21000, add(mul(sub(calldatasize(), baseCost), 4), mul(baseCost, 16))) + } + } + + /** + * @dev Compute the transaction base cost. + */ + function calldataBaseCost() internal pure returns (uint256) { + unchecked { + uint256 nonZeros = countNonZerosCalldata(msg.data); + uint256 zeros = msg.data.length - nonZeros; + return 21000 + (nonZeros * 16) + (zeros * 4); } } } diff --git a/test/GasUtils.t.sol b/test/GasUtils.t.sol index 4232c39..165fe43 100644 --- a/test/GasUtils.t.sol +++ b/test/GasUtils.t.sol @@ -30,6 +30,18 @@ import { uint256 constant secret = 0x42; uint256 constant nonce = 0x69; +contract GasUtilsMock { + function execute(Signature calldata, GmpMessage calldata) + external + pure + returns (uint256 baseCost, uint256 nonZeros, uint256 zeros) + { + baseCost = GasUtils.calldataBaseCost(); + nonZeros = GasUtils.countNonZerosCalldata(msg.data); + zeros = msg.data.length - nonZeros; + } +} + contract GasUtilsBase is Test { using PrimitiveUtils for UpdateKeysMessage; using PrimitiveUtils for GmpMessage; @@ -38,6 +50,7 @@ contract GasUtilsBase is Test { using GatewayUtils for CallOptions; using BranchlessMath for uint256; + GasUtilsMock internal mock; Gateway internal gateway; Signer internal signer; @@ -56,6 +69,9 @@ contract GasUtilsBase is Test { address deployer = TestUtils.createTestAccount(100 ether); vm.startPrank(deployer, deployer); + // Deploy the GasUtilsMock contract + mock = new GasUtilsMock(); + // 1 - Deploy the implementation contract address proxyAddr = vm.computeCreateAddress(deployer, vm.getNonce(deployer) + 1); Gateway implementation = new Gateway(DEST_NETWORK_ID, proxyAddr); @@ -75,6 +91,14 @@ contract GasUtilsBase is Test { _srcDomainSeparator = GatewayUtils.computeDomainSeparator(SRC_NETWORK_ID, address(gateway)); _dstDomainSeparator = GatewayUtils.computeDomainSeparator(DEST_NETWORK_ID, address(gateway)); + // Obs: This is a special contract that wastes an exact amount of gas you send to it, helpful for testing GMP refunds and gas limits. + // See the file `HelperContract.opcode` for more details. + { + bytes memory bytecode = + hex"603c80600a5f395ff3fe5a600201803d523d60209160643560240135146018575bfd5b60365a116018575a604903565b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5bf3"; + receiver = IGmpReceiver(TestUtils.deployContract(bytecode)); + } + vm.stopPrank(); } @@ -91,42 +115,110 @@ contract GasUtilsBase is Test { } /** - * @dev Compare the estimated gas cost VS the actual gas cost of the `execute` method. + * @dev Create a GMP message with the provided parameters. */ - function test_baseExecutionCost(uint16 messageSize) external { - vm.assume(messageSize <= 0x6000); - vm.txGasPrice(1); - address sender = TestUtils.createTestAccount(100 ether); + function _buildGmpMessage(address sender, uint256 gasLimit, uint256 gasUsed, uint256 messageSize) + private + view + returns (GmpMessage memory message, Signature memory signature, CallOptions memory context) + { + require(gasUsed == 0 || messageSize >= 32, "If gasUsed > 0, then messageSize must be >= 32"); + require(messageSize <= 0x6000, "message is too big"); - // Build and sign GMP message - GmpMessage memory gmp = GmpMessage({ + // Setup data and receiver addresses. + bytes memory data = new bytes(messageSize); + address gmpReceiver; + if (gasUsed > 0) { + gmpReceiver = address(receiver); + assembly { + mstore(add(data, 32), gasUsed) + } + } else { + // Create a new unique receiver address for each message, otherwise the gas refund will not work. + gmpReceiver = address(bytes20(keccak256(abi.encode(sender, gasLimit, messageSize)))); + } + + // Build the GMP message + message = GmpMessage({ source: sender.toSender(false), srcNetwork: SRC_NETWORK_ID, - dest: address(bytes20(keccak256("dummy_address"))), + dest: gmpReceiver, destNetwork: DEST_NETWORK_ID, - gasLimit: 0, + gasLimit: gasLimit, salt: 0, - data: new bytes(messageSize) + data: data }); - Signature memory sig = sign(gmp); + // Sign the message + signature = sign(message); // Calculate memory expansion cost and base cost - (uint256 baseCost,) = GatewayUtils.computeGmpGasCost(sig, gmp); + (uint256 baseCost, uint256 executionCost) = GatewayUtils.computeGmpGasCost(signature, message); - // Transaction Parameters - CallOptions memory ctx = CallOptions({ + // Set Transaction Parameters + context = CallOptions({ from: sender, to: address(gateway), value: 0, - // gasLimit: GasUtils.executionGasNeeded(gmp.data.length, gmp.gasLimit) + baseCost, - gasLimit: baseCost + 1_000_000, - executionCost: 0, - baseCost: 0 + gasLimit: GasUtils.executionGasNeeded(message.data.length, message.gasLimit).saturatingAdd(baseCost), + executionCost: executionCost, + baseCost: baseCost + }); + } + + /** + * Test the `GasUtils.calldataBaseCost` method. + */ + function test_calldataBaseCost() external view { + // Build and sign GMP message + GmpMessage memory gmp = GmpMessage({ + source: address(0x1111111111111111111111111111111111111111).toSender(false), + srcNetwork: 1234, + dest: address(0x2222222222222222222222222222222222222222), + destNetwork: 1337, + gasLimit: 0, + salt: 0, + data: hex"00" }); + Signature memory sig = sign(gmp); + + // Check if `IExecutor.execute` match the expected base cost + (uint256 baseCost, uint256 nonZeros, uint256 zeros) = mock.execute(sig, gmp); + assertEq(baseCost, 24444, "Wrong calldata gas cost"); + assertEq(nonZeros, 147, "wrong number of non-zeros"); + assertEq(zeros, 273, "wrong number of zeros"); + } + + /** + * @dev Compare the estimated gas cost VS the actual gas cost of the `execute` method. + */ + function test_baseExecutionCost(uint16 messageSize, uint16 gasLimit) external { + vm.assume(gasLimit >= 5000); + vm.assume(messageSize <= (0x6000 - 32)); + messageSize += 32; + vm.txGasPrice(1); + address sender = TestUtils.createTestAccount(100 ether); + + // Build the GMP message + GmpMessage memory gmp; + Signature memory sig; + CallOptions memory ctx; + (gmp, sig, ctx) = _buildGmpMessage(sender, gasLimit, gasLimit, messageSize); + + // Increase the gas limit to avoid out-of-gas errors + ctx.gasLimit = ctx.gasLimit.saturatingAdd(10_000_000); // Execute the GMP message - ctx.execute(sig, gmp); + { + bytes32 gmpId = gmp.eip712TypedHash(_dstDomainSeparator); + vm.expectEmit(true, true, true, true); + emit IExecutor.GmpExecuted(gmpId, gmp.source, gmp.dest, GmpStatus.SUCCESS, bytes32(uint256(gasLimit))); + uint256 balanceBefore = ctx.from.balance; + (GmpStatus status, bytes32 result) = ctx.execute(sig, gmp); + assertEq(uint256(status), uint256(GmpStatus.SUCCESS), "GMP execution failed"); + assertEq(result, bytes32(uint256(gasLimit)), "unexpected result"); + assertEq(balanceBefore, ctx.from.balance, "Balance should not change"); + } // Calculate the expected base cost uint256 dynamicCost = @@ -136,22 +228,22 @@ contract GasUtilsBase is Test { } function test_gasUtils() external pure { - assertEq(GasUtils.estimateGas(0, 0, 0), 76208); - assertEq(GasUtils.estimateGas(0, 33, 0), 76369); - assertEq(GasUtils.estimateGas(33, 0, 0), 77029); - assertEq(GasUtils.estimateGas(20, 13, 0), 76769); + assertEq(GasUtils.estimateGas(0, 0, 0), 76186); + assertEq(GasUtils.estimateGas(0, 33, 0), 76559); + assertEq(GasUtils.estimateGas(33, 0, 0), 77219); + assertEq(GasUtils.estimateGas(20, 13, 0), 76959); UFloat9x56 one = UFloatMath.ONE; - assertEq(GasUtils.estimateWeiCost(one, 0, 0, 0, 0), 76208); - assertEq(GasUtils.estimateWeiCost(one, 0, 0, 33, 0), 76369); - assertEq(GasUtils.estimateWeiCost(one, 0, 33, 0, 0), 77029); - assertEq(GasUtils.estimateWeiCost(one, 0, 20, 13, 0), 76769); + assertEq(GasUtils.estimateWeiCost(one, 0, 0, 0, 0), 76186); + assertEq(GasUtils.estimateWeiCost(one, 0, 0, 33, 0), 76559); + assertEq(GasUtils.estimateWeiCost(one, 0, 33, 0, 0), 77219); + assertEq(GasUtils.estimateWeiCost(one, 0, 20, 13, 0), 76959); UFloat9x56 two = UFloat9x56.wrap(0x8080000000000000); - assertEq(GasUtils.estimateWeiCost(two, 0, 0, 0, 0), 76208 * 2); - assertEq(GasUtils.estimateWeiCost(two, 0, 0, 33, 0), 76369 * 2); - assertEq(GasUtils.estimateWeiCost(two, 0, 33, 0, 0), 77029 * 2); - assertEq(GasUtils.estimateWeiCost(two, 0, 20, 13, 0), 76769 * 2); + assertEq(GasUtils.estimateWeiCost(two, 0, 0, 0, 0), 76186 * 2); + assertEq(GasUtils.estimateWeiCost(two, 0, 0, 33, 0), 76559 * 2); + assertEq(GasUtils.estimateWeiCost(two, 0, 33, 0, 0), 77219 * 2); + assertEq(GasUtils.estimateWeiCost(two, 0, 20, 13, 0), 76959 * 2); } } diff --git a/test/Gateway.t.sol b/test/Gateway.t.sol index d1bddc0..28e4254 100644 --- a/test/Gateway.t.sol +++ b/test/Gateway.t.sol @@ -100,7 +100,7 @@ library GatewayUtils { pure returns (uint256 baseCost, uint256 executionCost) { - (, executionCost) = GasUtils.executionGasCost(message.data.length); + executionCost = GasUtils.computeExecutionRefund(uint16(message.data.length), 0); bytes memory encodedCall = abi.encodeCall(IExecutor.execute, (signature, message)); baseCost = TestUtils.calculateBaseCost(encodedCall); } @@ -137,7 +137,7 @@ contract GatewayBase is Test { bytes32 private _srcDomainSeparator; bytes32 private _dstDomainSeparator; - uint256 private constant SUBMIT_GAS_COST = 6095 + 9206; + uint256 private constant SUBMIT_GAS_COST = 15034; uint16 private constant SRC_NETWORK_ID = 1234; uint16 internal constant DEST_NETWORK_ID = 1337; uint8 private constant GMP_STATUS_SUCCESS = 1; @@ -178,7 +178,7 @@ contract GatewayBase is Test { // See the file `HelperContract.opcode` for more details. { bytes memory bytecode = - hex"603b80600c6000396000f3fe5a600201803d523d60209160643560240135146018575bfd5b60345a116018575a604803565b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5bf3"; + hex"603c80600a5f395ff3fe5a600201803d523d60209160643560240135146018575bfd5b60365a116018575a604903565b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5bf3"; receiver = IGmpReceiver(TestUtils.deployContract(bytecode)); } } @@ -216,10 +216,10 @@ contract GatewayBase is Test { function test_estimateMessageCost() external { vm.txGasPrice(1); uint256 cost = gateway.estimateMessageCost(DEST_NETWORK_ID, 96, 100000); - assertEq(cost, 180028); + assertEq(cost, 178479); } - function test_gasMeter() external { + function test_checkPayloadSize() external { vm.txGasPrice(1); address sender = TestUtils.createTestAccount(100 ether); @@ -231,9 +231,52 @@ contract GatewayBase is Test { destNetwork: DEST_NETWORK_ID, gasLimit: 0, salt: 0, - data: hex"" + data: new bytes(24576 + 1) + }); + + Signature memory sig = sign(gmp); + + // Calculate memory expansion cost and base cost + (uint256 baseCost, uint256 executionCost) = GatewayUtils.computeGmpGasCost(sig, gmp); + + // Transaction Parameters + CallOptions memory ctx = CallOptions({ + from: sender, + to: address(gateway), + value: 0, + gasLimit: GasUtils.executionGasNeeded(gmp.data.length, gmp.gasLimit) + baseCost + 1_000_000, + executionCost: 0, + baseCost: 0 + }); + + GmpStatus status; + bytes32 returned; + + // Expect a revert + vm.expectRevert("msg data too large"); + (status, returned) = ctx.execute(sig, gmp); + assertLt(ctx.executionCost, executionCost, "revert should use less gas!!"); + assertEq(ctx.baseCost, baseCost, "unexpected base cost"); + } + + /** + * @dev Test the gas metering for the `execute` function. + */ + function test_gasMeter(uint16 messageSize) external { + vm.assume(messageSize < 1000); + vm.txGasPrice(1); + address sender = TestUtils.createTestAccount(100 ether); + + // Build and sign GMP message + GmpMessage memory gmp = GmpMessage({ + source: sender.toSender(false), + srcNetwork: SRC_NETWORK_ID, + dest: address(bytes20(keccak256("dummy_address"))), + destNetwork: DEST_NETWORK_ID, + gasLimit: 0, + salt: 0, + data: new bytes(messageSize) }); - // ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff Signature memory sig = sign(gmp); @@ -246,7 +289,6 @@ contract GatewayBase is Test { to: address(gateway), value: 0, gasLimit: GasUtils.executionGasNeeded(gmp.data.length, gmp.gasLimit) + baseCost - 1, - // gasLimit: 100_000, executionCost: 0, baseCost: 0 }); @@ -255,7 +297,6 @@ contract GatewayBase is Test { bytes32 returned; // Expect a revert - // vm.expectRevert("insufficient gas to execute GMP message"); vm.expectRevert(); (status, returned) = ctx.execute(sig, gmp); @@ -269,21 +310,101 @@ contract GatewayBase is Test { ctx.gasLimit += 1; (status, returned) = ctx.execute(sig, gmp); + assertEq(uint256(status), uint256(GmpStatus.SUCCESS), "gmp execution failed"); + assertEq(uint256(returned), gmp.gasLimit, "wrong gmp return value"); assertEq(ctx.baseCost, baseCost, "ctx.baseCost != baseCost"); assertEq(ctx.executionCost, executionCost, "ctx.executionCost != executionCost"); assertEq(gatewayBalance - address(gateway).balance, executionCost + baseCost, "wrong refund amount"); assertEq(senderBalance, address(sender).balance, "sender balance should not change"); + // Submit the transaction { uint256 nonZeros = GasUtils.countNonZeros(gmp.data); uint256 zeros = gmp.data.length - nonZeros; ctx.value = GasUtils.estimateGas(uint16(nonZeros), uint16(zeros), gmp.gasLimit) - 1; } + + // Must revert if fund are insufficient vm.expectRevert("insufficient tx value"); ctx.submitMessage(gmp); + { + bytes memory submitEncoded = + abi.encodeCall(IGateway.submitMessage, (gmp.dest, gmp.destNetwork, gmp.gasLimit, gmp.data)); + assertEq(submitEncoded.length, ((gmp.data.length + 31) & 0xffe0) + 164, "wrong encoded length"); + } + + // Must work if the funds are sufficient ctx.value += 1; + ctx.gasLimit += gmp.data.length * 8; + ctx.submitMessage(gmp); + + assertEq( + ctx.executionCost, + GasUtils.submitMessageGasCost(uint16(gmp.data.length)) - 4500, + "unexpected submit message gas cost" + ); + } + + function test_submitMessageMeter(uint16 messageSize) external { + vm.assume(messageSize < 0x6000); + vm.txGasPrice(1); + address sender = TestUtils.createTestAccount(1000 ether); + + // Build and sign GMP message + GmpMessage memory gmp = GmpMessage({ + source: sender.toSender(false), + srcNetwork: DEST_NETWORK_ID, + dest: address(bytes20(keccak256("dummy_address"))), + destNetwork: DEST_NETWORK_ID, + gasLimit: 0, + salt: 0, + data: new bytes(messageSize) + }); + + Signature memory sig = sign(gmp); + + // Calculate memory expansion cost and base cost + (uint256 baseCost,) = GatewayUtils.computeGmpGasCost(sig, gmp); + + // Transaction Parameters + CallOptions memory ctx = CallOptions({ + from: sender, + to: address(gateway), + value: 0, + gasLimit: GasUtils.executionGasNeeded(gmp.data.length, gmp.gasLimit) + baseCost, + executionCost: 0, + baseCost: 0 + }); + + // Submit the transaction + { + uint256 nonZeros = GasUtils.countNonZeros(gmp.data); + uint256 zeros = gmp.data.length - nonZeros; + ctx.value = GasUtils.estimateGas(uint16(nonZeros), uint16(zeros), gmp.gasLimit) - 1; + } + + // Must revert if fund are insufficient + vm.expectRevert("insufficient tx value"); ctx.submitMessage(gmp); + + // Must work if the funds are sufficient + ctx.value += 1; + ctx.gasLimit += gmp.data.length * 8; + bytes32 id = gmp.eip712TypedHash(_dstDomainSeparator); + vm.expectEmit(true, true, true, true); + emit IGateway.GmpCreated( + id, GmpSender.unwrap(gmp.source), gmp.dest, gmp.destNetwork, gmp.gasLimit, gmp.salt, gmp.data + ); + bytes32 returnedId = ctx.submitMessage(gmp); + assertEq(returnedId, id, "unexpected GMP id"); + + // Verify the execution cost + assertEq( + ctx.executionCost, + GasUtils.submitMessageGasCost(uint16(gmp.data.length)), + "unexpected submit message gas cost" + ); } function test_refund() external { @@ -291,7 +412,7 @@ contract GatewayBase is Test { GmpSender sender = TestUtils.createTestAccount(100 ether).toSender(false); // GMP message gas used - uint256 gmpGasUsed = 1_000; + uint256 gmpGasUsed = 2_000; // Build and sign GMP message GmpMessage memory gmp = GmpMessage({ @@ -305,38 +426,21 @@ contract GatewayBase is Test { }); Signature memory sig = sign(gmp); - // Deposit funds + // Estimate execution cost (uint256 baseCost, uint256 executionCost) = GatewayUtils.computeGmpGasCost(sig, gmp); uint256 expectGasUsed = baseCost + executionCost + gmp.gasLimit; - // Calculate memory expansion cost and base cost - // uint256 baseCost; - // { - // bytes memory encodedExecuteCall = abi.encodeCall(IExecutor.execute, (sig, gmp)); - // baseCost = TestUtils.calculateBaseCost(encodedExecuteCall); - // expectGasUsed += TestUtils.memExpansionCost(encodedExecuteCall.length); - // } - - // Deposit funds - // { - // GmpSender gmpSender = sender.toSender(false); - // assertEq(gateway.depositOf(gmpSender, DEST_NETWORK_ID), 0); - // vm.prank(sender, sender); - // gateway.deposit{value: expectGasUsed + baseCost}(gmpSender, DEST_NETWORK_ID); - // assertEq(gateway.depositOf(gmpSender, DEST_NETWORK_ID), expectGasUsed + baseCost); - // } - // Execute GMP message uint256 beforeBalance = sender.toAddress().balance; - CallOptions memory ctx = CallOptions({ - from: sender.toAddress(), - to: address(gateway), - value: 0, - gasLimit: expectGasUsed + 2160 + 785 + 10, - executionCost: 0, - baseCost: 0 - }); { + CallOptions memory ctx = CallOptions({ + from: sender.toAddress(), + to: address(gateway), + value: 0, + gasLimit: GasUtils.executionGasNeeded(gmp.data.length, gmp.gasLimit) + baseCost, + executionCost: 0, + baseCost: 0 + }); (GmpStatus status, bytes32 returned) = ctx.execute(sig, gmp); // Verify the GMP message status @@ -349,6 +453,7 @@ contract GatewayBase is Test { // Verify the gas cost assertEq(ctx.executionCost + ctx.baseCost, expectGasUsed, "unexpected gas used"); + assertEq(ctx.executionCost, executionCost + gmp.gasLimit, "unexpected execution cost"); } // Verify the gas refund @@ -489,12 +594,12 @@ contract GatewayBase is Test { id, GmpSender.unwrap(gmp.source), gmp.dest, gmp.destNetwork, gmp.gasLimit, gmp.salt, gmp.data ); - // Submit message + // Submit message with sufficient funds ctx.value += 1; - ctx.submitMessage(gmp); + assertEq(ctx.submitMessage(gmp), id, "unexpected GMP id"); // Verify the gas cost - uint256 expectedCost = SUBMIT_GAS_COST + 2800 + 2000 + 2000; + uint256 expectedCost = GasUtils.submitMessageGasCost(uint16(gmp.data.length)) - 6500; assertEq(ctx.executionCost, expectedCost, "unexpected execution gas cost"); // Now the second GMP message should have the salt equals to previous gmp hash @@ -506,13 +611,8 @@ contract GatewayBase is Test { emit IGateway.GmpCreated( id, GmpSender.unwrap(gmp.source), gmp.dest, gmp.destNetwork, gmp.gasLimit, gmp.salt, gmp.data ); - ctx.submitMessage(gmp); - - if (ctx.baseCost > 0) { - return; - } - expectedCost = SUBMIT_GAS_COST; - assertEq(ctx.executionCost, expectedCost, "unexpected execution gas cost"); + assertEq(ctx.submitMessage(gmp), id, "unexpected GMP id"); + assertEq(ctx.executionCost, expectedCost - 6800, "unexpected execution gas cost"); } } diff --git a/test/GmpTestTools.sol b/test/GmpTestTools.sol index ef72c5c..f13c214 100644 --- a/test/GmpTestTools.sol +++ b/test/GmpTestTools.sol @@ -236,7 +236,7 @@ library GmpTestTools { bytes32 slot = _deriveMapping(bytes32(0), shard.pubkey.px); uint256 shardInfo = uint256(vm.load(gateway, slot)); uint256 nonce = shardInfo >> 224; - nonce = BranchlessMath.select(nonce > 0, nonce, 1); + nonce = BranchlessMath.ternary(nonce > 0, nonce, 1); shardInfo = (nonce << 224) | (1 << 216) | ((shard.pubkey.py % 2) << 217); vm.store(gateway, slot, bytes32(shardInfo)); } diff --git a/test/TestUtils.sol b/test/TestUtils.sol index 93d3520..cd52c04 100644 --- a/test/TestUtils.sol +++ b/test/TestUtils.sol @@ -6,6 +6,7 @@ pragma solidity >=0.8.0; import {VmSafe, Vm} from "forge-std/Vm.sol"; import {Schnorr} from "@frost-evm/Schnorr.sol"; import {SECP256K1} from "@frost-evm/SECP256K1.sol"; +import {BranchlessMath} from "../src/utils/BranchlessMath.sol"; struct VerifyingKey { uint256 px; @@ -21,6 +22,8 @@ struct SigningKey { * @dev Utilities for testing purposes */ library TestUtils { + using BranchlessMath for uint256; + // Cheat code address, 0x7109709ECfa91a80626fF3989D68f67F5b1DD12D. address internal constant VM_ADDRESS = address(uint160(uint256(keccak256("hevm cheat code")))); @@ -196,7 +199,7 @@ library TestUtils { private returns (uint256 gasUsed, bool success, bytes memory out) { - require(gasleft() > (gasLimit + 5000), "insufficient gas"); + require(gasleft() > gasLimit.saturatingAdd(5000), "insufficient gas"); require(addr.code.length > 0, "Not a contract address"); /// @solidity memory-safe-assembly assembly { @@ -231,8 +234,8 @@ library TestUtils { // Decrement sender base cost { - uint256 txFees = (baseCost + gasLimit) * tx.gasprice; - require(sender.balance >= (txFees + value), "account has no sufficient funds"); + uint256 txFees = baseCost.saturatingAdd(gasLimit).saturatingMul(tx.gasprice); + require(sender.balance >= txFees.saturatingAdd(value), "account has no sufficient funds"); vm.deal(sender, sender.balance - txFees); } @@ -241,12 +244,12 @@ library TestUtils { { (VmSafe.CallerMode callerMode, address msgSender, address txOrigin) = setCallerMode(VmSafe.CallerMode.RecurrentPrank, sender, sender); - (executionCost, success, out) = _call(dest, gasLimit - baseCost, value, data); + (executionCost, success, out) = _call(dest, gasLimit.saturatingSub(baseCost), value, data); setCallerMode(callerMode, msgSender, txOrigin); } // Refund unused gas - uint256 refund = (gasLimit - executionCost) * tx.gasprice; + uint256 refund = gasLimit.saturatingSub(executionCost).saturatingMul(tx.gasprice); if (refund > 0) { vm.deal(sender, sender.balance + refund); }