Skip to content

Commit

Permalink
Merge pull request #183 from bcnmy/fix/hook-fallback-issue-58
Browse files Browse the repository at this point in the history
Fix hooking the `fallback()`
  • Loading branch information
livingrockrises authored Sep 26, 2024
2 parents c81e30c + 27b8627 commit a98c93f
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 64 deletions.
107 changes: 53 additions & 54 deletions contracts/base/ModuleManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { IExecutor } from "../interfaces/modules/IExecutor.sol";
import { IFallback } from "../interfaces/modules/IFallback.sol";
import { IValidator } from "../interfaces/modules/IValidator.sol";
import { CallType, CALLTYPE_SINGLE, CALLTYPE_STATIC } from "../lib/ModeLib.sol";
import { ExecLib } from "../lib/ExecLib.sol";
import { LocalCallDataParserLib } from "../lib/local/LocalCallDataParserLib.sol";
import { IModuleManagerEventsAndErrors } from "../interfaces/base/IModuleManagerEventsAndErrors.sol";
import { MODULE_TYPE_VALIDATOR, MODULE_TYPE_EXECUTOR, MODULE_TYPE_FALLBACK, MODULE_TYPE_HOOK, MODULE_TYPE_MULTI, MODULE_ENABLE_MODE_TYPE_HASH, ERC1271_MAGICVALUE } from "contracts/types/Constants.sol";
Expand All @@ -39,6 +40,8 @@ import { SentinelListLib } from "sentinellist/src/SentinelList.sol";
abstract contract ModuleManager is Storage, EIP712, IModuleManagerEventsAndErrors, RegistryAdapter {
using SentinelListLib for SentinelListLib.SentinelList;
using LocalCallDataParserLib for bytes;
using ExecLib for address;
using ExcessivelySafeCall for address;

/// @notice Ensures the message sender is a registered executor module.
modifier onlyExecutorModule() virtual {
Expand All @@ -62,56 +65,8 @@ abstract contract ModuleManager is Storage, EIP712, IModuleManagerEventsAndError
receive() external payable {}

/// @dev Fallback function to manage incoming calls using designated handlers based on the call type.
fallback() external payable withHook {
FallbackHandler storage $fallbackHandler = _getAccountStorage().fallbacks[msg.sig];
address handler = $fallbackHandler.handler;
CallType calltype = $fallbackHandler.calltype;
if (handler != address(0)) {
if (calltype == CALLTYPE_STATIC) {
assembly {
calldatacopy(0, 0, calldatasize())

// The msg.sender address is shifted to the left by 12 bytes to remove the padding
// Then the address without padding is stored right after the calldata
mstore(calldatasize(), shl(96, caller()))

if iszero(staticcall(gas(), handler, 0, add(calldatasize(), 20), 0, 0)) {
returndatacopy(0, 0, returndatasize())
revert(0, returndatasize())
}
returndatacopy(0, 0, returndatasize())
return(0, returndatasize())
}
}
if (calltype == CALLTYPE_SINGLE) {
assembly {
calldatacopy(0, 0, calldatasize())

// The msg.sender address is shifted to the left by 12 bytes to remove the padding
// Then the address without padding is stored right after the calldata
mstore(calldatasize(), shl(96, caller()))

if iszero(call(gas(), handler, callvalue(), 0, add(calldatasize(), 20), 0, 0)) {
returndatacopy(0, 0, returndatasize())
revert(0, returndatasize())
}
returndatacopy(0, 0, returndatasize())
return(0, returndatasize())
}
}
}
/// @solidity memory-safe-assembly
assembly {
let s := shr(224, calldataload(0))
// 0x150b7a02: `onERC721Received(address,address,uint256,bytes)`.
// 0xf23a6e61: `onERC1155Received(address,address,uint256,uint256,bytes)`.
// 0xbc197c81: `onERC1155BatchReceived(address,address,uint256[],uint256[],bytes)`.
if or(eq(s, 0x150b7a02), or(eq(s, 0xf23a6e61), eq(s, 0xbc197c81))) {
mstore(0x20, s) // Store `msg.sig`.
return(0x3c, 0x20) // Return `msg.sig`.
}
}
revert MissingFallbackHandler(msg.sig);
fallback(bytes calldata callData) external payable withHook returns (bytes memory) {
return _fallback(callData);
}

/// @dev Retrieves a paginated list of validator addresses from the linked list.
Expand Down Expand Up @@ -228,7 +183,7 @@ abstract contract ModuleManager is Storage, EIP712, IModuleManagerEventsAndError
// Sentinel pointing to itself means the list is empty, so check this after removal
// Below error is very specific to uninstalling validators.
require(_hasValidators(), CanNotRemoveLastValidator());
ExcessivelySafeCall.excessivelySafeCall(validator, gasleft(), 0, 0, abi.encodeWithSelector(IModule.onUninstall.selector, disableModuleData));
validator.excessivelySafeCall(gasleft(), 0, 0, abi.encodeWithSelector(IModule.onUninstall.selector, disableModuleData));
}

/// @dev Installs a new executor module after checking if it matches the required module type.
Expand All @@ -246,7 +201,7 @@ abstract contract ModuleManager is Storage, EIP712, IModuleManagerEventsAndError
function _uninstallExecutor(address executor, bytes calldata data) internal virtual {
(address prev, bytes memory disableModuleData) = abi.decode(data, (address, bytes));
_getAccountStorage().executors.pop(prev, executor);
ExcessivelySafeCall.excessivelySafeCall(executor, gasleft(), 0, 0, abi.encodeWithSelector(IModule.onUninstall.selector, disableModuleData));
executor.excessivelySafeCall(gasleft(), 0, 0, abi.encodeWithSelector(IModule.onUninstall.selector, disableModuleData));
}

/// @dev Installs a hook module, ensuring no other hooks are installed before proceeding.
Expand All @@ -265,7 +220,7 @@ abstract contract ModuleManager is Storage, EIP712, IModuleManagerEventsAndError
/// @param data De-initialization data to configure the hook upon uninstallation.
function _uninstallHook(address hook, bytes calldata data) internal virtual {
_setHook(address(0));
ExcessivelySafeCall.excessivelySafeCall(hook, gasleft(), 0, 0, abi.encodeWithSelector(IModule.onUninstall.selector, data));
hook.excessivelySafeCall(gasleft(), 0, 0, abi.encodeWithSelector(IModule.onUninstall.selector, data));
}

/// @dev Sets the current hook in the storage to the specified address.
Expand Down Expand Up @@ -316,7 +271,7 @@ abstract contract ModuleManager is Storage, EIP712, IModuleManagerEventsAndError
/// @param data The de-initialization data containing the selector.
function _uninstallFallbackHandler(address fallbackHandler, bytes calldata data) internal virtual {
_getAccountStorage().fallbacks[bytes4(data[0:4])] = FallbackHandler(address(0), CallType.wrap(0x00));
ExcessivelySafeCall.excessivelySafeCall(fallbackHandler, gasleft(), 0, 0, abi.encodeWithSelector(IModule.onUninstall.selector, data[4:]));
fallbackHandler.excessivelySafeCall(gasleft(), 0, 0, abi.encodeWithSelector(IModule.onUninstall.selector, data[4:]));
}

/// @notice Installs a module with multiple types in a single operation.
Expand Down Expand Up @@ -470,6 +425,50 @@ abstract contract ModuleManager is Storage, EIP712, IModuleManagerEventsAndError
hook = address(_getAccountStorage().hook);
}

function _fallback(bytes calldata callData) private returns (bytes memory result) {
bool success;
FallbackHandler storage $fallbackHandler = _getAccountStorage().fallbacks[msg.sig];
address handler = $fallbackHandler.handler;
CallType calltype = $fallbackHandler.calltype;

if (handler != address(0)) {
//if there's a fallback handler, call it
if (calltype == CALLTYPE_STATIC) {
(success, result) = handler.staticcall(ExecLib.get2771CallData(callData));

Check warning on line 437 in contracts/base/ModuleManager.sol

View check run for this annotation

Codecov / codecov/patch

contracts/base/ModuleManager.sol#L437

Added line #L437 was not covered by tests
} else if (calltype == CALLTYPE_SINGLE) {
(success, result) = handler.call{ value: msg.value }(ExecLib.get2771CallData(callData));

Check warning on line 439 in contracts/base/ModuleManager.sol

View check run for this annotation

Codecov / codecov/patch

contracts/base/ModuleManager.sol#L439

Added line #L439 was not covered by tests
} else {
revert UnsupportedCallType(calltype);
}

// Use revert message from fallback handler if the call was not successful
if (!success) {
assembly {

Check warning on line 446 in contracts/base/ModuleManager.sol

View check run for this annotation

Codecov / codecov/patch

contracts/base/ModuleManager.sol#L446

Added line #L446 was not covered by tests
revert(add(result, 0x20), mload(result))
}
}
} else {
// If there's no handler, the call can be one of onERCXXXReceived()
bytes32 s;
/// @solidity memory-safe-assembly
assembly {
s := shr(224, calldataload(0))
// 0x150b7a02: `onERC721Received(address,address,uint256,bytes)`.
// 0xf23a6e61: `onERC1155Received(address,address,uint256,uint256,bytes)`.

Check warning on line 457 in contracts/base/ModuleManager.sol

View check run for this annotation

Codecov / codecov/patch

contracts/base/ModuleManager.sol#L457

Added line #L457 was not covered by tests
// 0xbc197c81: `onERC1155BatchReceived(address,address,uint256[],uint256[],bytes)`.
if or(eq(s, 0x150b7a02), or(eq(s, 0xf23a6e61), eq(s, 0xbc197c81))) {
success := true // it is one of onERCXXXReceived
result := mload(0x40) //result was set to 0x60 as it was empty, so we need to find a new space for it
mstore(result, 0x04) //store length
mstore(add(result, 0x20), shl(224, s)) //store calldata
mstore(0x40, add(result, 0x24)) //allocate memory
}
}
// if there was no handler and it is not the onERCXXXReceived call, revert
require(success, MissingFallbackHandler(msg.sig));
}
}

/// @dev Helper function to paginate entries in a SentinelList.
/// @param list The SentinelList to paginate.
/// @param cursor The cursor to start paginating from.
Expand Down
5 changes: 0 additions & 5 deletions contracts/interfaces/INexusEventsAndErrors.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ pragma solidity ^0.8.27;
// Nexus: A suite of contracts for Modular Smart Accounts compliant with ERC-7579 and ERC-4337, developed by Biconomy.
// Learn more at https://biconomy.io. To report security issues, please contact us at: [email protected]

import { CallType } from "../lib/ModeLib.sol";
import { PackedUserOperation } from "account-abstraction/contracts/interfaces/PackedUserOperation.sol";

/// @title Nexus - INexus Events and Errors
Expand All @@ -32,10 +31,6 @@ interface INexusEventsAndErrors {
/// @param moduleTypeId The ID of the unsupported module type.
error UnsupportedModuleType(uint256 moduleTypeId);

/// @notice Error thrown when an execution with an unsupported CallType was made.
/// @param callType The unsupported call type.
error UnsupportedCallType(CallType callType);

/// @notice Error thrown on failed execution.
error ExecutionFailed();

Expand Down
6 changes: 6 additions & 0 deletions contracts/interfaces/base/IModuleManagerEventsAndErrors.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pragma solidity ^0.8.27;
// Nexus: A suite of contracts for Modular Smart Accounts compliant with ERC-7579 and ERC-4337, developed by Biconomy.
// Learn more at https://biconomy.io. To report security issues, please contact us at: [email protected]

import { CallType } from "../../lib/ModeLib.sol";

/// @title ERC-7579 Module Manager Events and Errors Interface
/// @notice Provides event and error definitions for actions related to module management in smart accounts.
/// @dev Used by IModuleManager to define the events and errors associated with the installation and management of modules.
Expand Down Expand Up @@ -90,4 +92,8 @@ interface IModuleManagerEventsAndErrors {

/// @dev Thrown when there is an attempt to install a fallback handler with an invalid calltype for a given selector.
error FallbackCallTypeInvalid();

/// @notice Error thrown when an execution with an unsupported CallType was made.
/// @param callType The unsupported call type.
error UnsupportedCallType(CallType callType);
}
21 changes: 21 additions & 0 deletions contracts/lib/ExecLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,27 @@ import { Execution } from "../types/DataTypes.sol";
/// Helper Library for decoding Execution calldata
/// malloc for memory allocation is bad for gas. use this assembly instead
library ExecLib {
function get2771CallData(bytes calldata cd) internal view returns (bytes memory callData) {
/// @solidity memory-safe-assembly
(cd);
assembly {
// as per solidity docs
function allocate(length) -> pos {
pos := mload(0x40)
mstore(0x40, add(pos, length))
}

callData := allocate(add(calldatasize(), 0x20)) //allocate extra 0x20 to store length
mstore(callData, add(calldatasize(), 0x14)) //store length, extra 0x14 is for msg.sender address
calldatacopy(add(callData, 0x20), 0, calldatasize())

// The msg.sender address is shifted to the left by 12 bytes to remove the padding
// Then the address without padding is stored right after the calldata
let senderPtr := allocate(0x14)
mstore(senderPtr, shl(96, caller()))
}
}

function decodeBatch(bytes calldata callData) internal pure returns (Execution[] calldata executionBatch) {
/*
* Batch Call Calldata Layout
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ contract TestModuleManager_FallbackHandler is TestModuleManagement_Base {
function setUp() public {
init();

Execution[] memory execution = new Execution[](2);

// Custom data for installing the MockHandler with call type STATIC
bytes memory customData = abi.encode(bytes5(abi.encodePacked(GENERIC_FALLBACK_SELECTOR, CALLTYPE_SINGLE)));

Expand All @@ -22,13 +24,23 @@ contract TestModuleManager_FallbackHandler is TestModuleManagement_Base {
customData
);

Execution[] memory execution = new Execution[](1);
execution[0] = Execution(address(BOB_ACCOUNT), 0, callData);

callData = abi.encodeWithSelector(
IModuleManager.installModule.selector,
MODULE_TYPE_HOOK,
address(HOOK_MODULE),
""
);

execution[1] = Execution(address(BOB_ACCOUNT), 0, callData);

PackedUserOperation[] memory userOps = buildPackedUserOperation(BOB, BOB_ACCOUNT, EXECTYPE_DEFAULT, execution, address(VALIDATOR_MODULE));
ENTRYPOINT.handleOps(userOps, payable(address(BOB.addr)));

// Verify the fallback handler was installed
assertEq(BOB_ACCOUNT.isModuleInstalled(MODULE_TYPE_FALLBACK, address(HANDLER_MODULE), customData), true, "Fallback handler not installed");
assertEq(BOB_ACCOUNT.isModuleInstalled(MODULE_TYPE_HOOK, address(HOOK_MODULE), ""), true, "Hook not installed");
}

/// @notice Tests triggering the onGenericFallback function of the fallback handler.
Expand All @@ -47,7 +59,7 @@ contract TestModuleManager_FallbackHandler is TestModuleManagement_Base {
}

/// @notice Tests that handleOps triggers the generic fallback handler.
function test_HandleOpsTriggersGenericFallback() public {
function test_HandleOpsTriggersGenericFallback(bool skip) public {
// Prepare the operation that triggers the fallback handler
bytes memory dataToTriggerFallback = abi.encodeWithSelector(
MockHandler(address(0)).onGenericFallback.selector,
Expand All @@ -61,14 +73,26 @@ contract TestModuleManager_FallbackHandler is TestModuleManagement_Base {
// Prepare UserOperation
PackedUserOperation[] memory userOps = buildPackedUserOperation(BOB, BOB_ACCOUNT, EXECTYPE_DEFAULT, executions, address(VALIDATOR_MODULE));

// Expect the GenericFallbackCalled event from the MockHandler contract
vm.expectEmit(true, true, false, true, address(HANDLER_MODULE));
emit GenericFallbackCalled(address(this), 123, "Example data");
if (!skip) {
// Expect the GenericFallbackCalled event from the MockHandler contract
vm.expectEmit(true, true, false, true, address(HANDLER_MODULE));
emit GenericFallbackCalled(address(this), 123, "Example data");
}

// Call handleOps, which should trigger the fallback handler and emit the event
ENTRYPOINT.handleOps(userOps, payable(address(BOB.addr)));
}

/// @notice Tests that handleOps triggers the generic fallback handler.
function test_HandleOpsTriggersGenericFallback_IsProperlyHooked() public {
vm.expectEmit(address(HOOK_MODULE));
emit PreCheckCalled();
vm.expectEmit(address(HOOK_MODULE));
emit PostCheckCalled();
// skip fallback emit check as per Matching Sequences section here => https://book.getfoundry.sh/cheatcodes/expect-emit
test_HandleOpsTriggersGenericFallback({skip: true});
}

/// @notice Tests installing a fallback handler.
/// @param selector The function selector for the fallback handler.
function test_InstallFallbackHandler(bytes4 selector) internal {
Expand Down Expand Up @@ -264,4 +288,22 @@ contract TestModuleManager_FallbackHandler is TestModuleManagement_Base {

ENTRYPOINT.handleOps(userOps, payable(address(BOB.addr)));
}

function test_onTokenReceived_Success() public {
vm.startPrank(address(ENTRYPOINT));
//ERC-721
(bool success, bytes memory data) = address(BOB_ACCOUNT).call{value: 0}(hex'150b7a02');
assertTrue(success);
assertTrue(keccak256(data) == keccak256(bytes(hex'150b7a02')));
//ERC-1155
(success, data) = address(BOB_ACCOUNT).call{value: 0}(hex'f23a6e61');
assertTrue(success);
assertTrue(keccak256(data) == keccak256(bytes(hex'f23a6e61')));
//ERC-1155 Batch
(success, data) = address(BOB_ACCOUNT).call{value: 0}(hex'bc197c81');
assertTrue(success);
assertTrue(keccak256(data) == keccak256(bytes(hex'bc197c81')));

vm.stopPrank();
}
}

0 comments on commit a98c93f

Please sign in to comment.