Skip to content

Commit

Permalink
Extensible Fallback Handler (#851)
Browse files Browse the repository at this point in the history
This PR brings the `ExtensibleFallbackHandler` created
[here](https://github.com/rndlabs/safe-contracts/blob/merged-efh-sigmuxer/contracts/handler/ExtensibleFallbackHandler.sol)
to `safe-smart-account` repo. Changes taken based on [git
diff](main...rndlabs:safe-contracts:main).

Some small changes were made like:
- Adapting tests to the `safe-smart-account` repo.
- Remove unused import.
- Replace global importing (based on our lint setup) with specific
contracts required.
  • Loading branch information
remedcu authored Oct 29, 2024
1 parent b55fd8f commit 76ea23d
Show file tree
Hide file tree
Showing 15 changed files with 1,501 additions and 0 deletions.
28 changes: 28 additions & 0 deletions contracts/handler/ExtensibleFallbackHandler.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// SPDX-License-Identifier: LGPL-3.0-only
pragma solidity >=0.7.0 <0.9.0;

import {ERC165Handler} from "./extensible/ERC165Handler.sol";
import {IFallbackHandler, FallbackHandler} from "./extensible/FallbackHandler.sol";
import {ERC1271, ISignatureVerifierMuxer, SignatureVerifierMuxer} from "./extensible/SignatureVerifierMuxer.sol";
import {ERC721TokenReceiver, ERC1155TokenReceiver, TokenCallbacks} from "./extensible/TokenCallbacks.sol";

/**
* @title ExtensibleFallbackHandler - A fully extensible fallback handler for Safes
* @dev Designed to be used with Safe >= 1.3.0.
* @author mfw78 <[email protected]>
*/
contract ExtensibleFallbackHandler is FallbackHandler, SignatureVerifierMuxer, TokenCallbacks, ERC165Handler {
/**
* Specify specific interfaces (ERC721 + ERC1155) that this contract supports.
* @param interfaceId The interface ID to check for support
*/
function _supportsInterface(bytes4 interfaceId) internal pure override returns (bool) {
return
interfaceId == type(ERC1271).interfaceId ||
interfaceId == type(ISignatureVerifierMuxer).interfaceId ||
interfaceId == type(ERC165Handler).interfaceId ||
interfaceId == type(IFallbackHandler).interfaceId ||
interfaceId == type(ERC721TokenReceiver).interfaceId ||
interfaceId == type(ERC1155TokenReceiver).interfaceId;
}
}
1 change: 1 addition & 0 deletions contracts/handler/HandlerContext.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ abstract contract HandlerContext {
* @return sender Original caller address.
*/
function _msgSender() internal pure returns (address sender) {
require(msg.data.length >= 20, "Invalid calldata length");
// The assembly code is more direct than the Solidity version using `abi.decode`.
/* solhint-disable no-inline-assembly */
/// @solidity memory-safe-assembly
Expand Down
113 changes: 113 additions & 0 deletions contracts/handler/extensible/ERC165Handler.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// SPDX-License-Identifier: LGPL-3.0-only
pragma solidity >=0.7.0 <0.9.0;

import {IERC165} from "../../interfaces/IERC165.sol";
import {Safe, MarshalLib, ExtensibleBase} from "./ExtensibleBase.sol";

interface IERC165Handler {
function safeInterfaces(Safe safe, bytes4 interfaceId) external view returns (bool);

function setSupportedInterface(bytes4 interfaceId, bool supported) external;

function addSupportedInterfaceBatch(bytes4 interfaceId, bytes32[] calldata handlerWithSelectors) external;

function removeSupportedInterfaceBatch(bytes4 interfaceId, bytes4[] calldata selectors) external;
}

abstract contract ERC165Handler is ExtensibleBase, IERC165Handler {
// --- events ---

event AddedInterface(Safe indexed safe, bytes4 interfaceId);
event RemovedInterface(Safe indexed safe, bytes4 interfaceId);

// --- storage ---

mapping(Safe => mapping(bytes4 => bool)) public override safeInterfaces;

// --- setters ---

/**
* Setter to indicate if an interface is supported (and thus reported by ERC165 supportsInterface)
* @param interfaceId The interface id whose support is to be set
* @param supported True if the interface is supported, false otherwise
*/
function setSupportedInterface(bytes4 interfaceId, bool supported) public override onlySelf {
Safe safe = Safe(payable(_manager()));
// invalid interface id per ERC165 spec
require(interfaceId != 0xffffffff, "invalid interface id");
bool current = safeInterfaces[safe][interfaceId];
if (supported && !current) {
safeInterfaces[safe][interfaceId] = true;
emit AddedInterface(safe, interfaceId);
} else if (!supported && current) {
delete safeInterfaces[safe][interfaceId];
emit RemovedInterface(safe, interfaceId);
}
}

/**
* Batch add selectors for an interface.
* @param _interfaceId The interface id to set
* @param handlerWithSelectors The handlers encoded with the 4-byte selectors of the methods
*/
function addSupportedInterfaceBatch(bytes4 _interfaceId, bytes32[] calldata handlerWithSelectors) external override onlySelf {
Safe safe = Safe(payable(_msgSender()));
bytes4 interfaceId;
for (uint256 i = 0; i < handlerWithSelectors.length; i++) {
(bool isStatic, bytes4 selector, address handlerAddress) = MarshalLib.decodeWithSelector(handlerWithSelectors[i]);
_setSafeMethod(safe, selector, MarshalLib.encode(isStatic, handlerAddress));
if (i > 0) {
interfaceId ^= selector;
} else {
interfaceId = selector;
}
}

require(interfaceId == _interfaceId, "interface id mismatch");
setSupportedInterface(_interfaceId, true);
}

/**
* Batch remove selectors for an interface.
* @param _interfaceId the interface id to remove
* @param selectors The selectors of the methods to remove
*/
function removeSupportedInterfaceBatch(bytes4 _interfaceId, bytes4[] calldata selectors) external override onlySelf {
Safe safe = Safe(payable(_msgSender()));
bytes4 interfaceId;
for (uint256 i = 0; i < selectors.length; i++) {
_setSafeMethod(safe, selectors[i], bytes32(0));
if (i > 0) {
interfaceId ^= selectors[i];
} else {
interfaceId = selectors[i];
}
}

require(interfaceId == _interfaceId, "interface id mismatch");
setSupportedInterface(_interfaceId, false);
}

/**
* @notice Implements ERC165 interface detection for the supported interfaces
* @dev Inheriting contracts should override `_supportsInterface` to add support for additional interfaces
* @param interfaceId The ERC165 interface id to check
* @return True if the interface is supported
*/
function supportsInterface(bytes4 interfaceId) external view returns (bool) {
return
interfaceId == type(IERC165).interfaceId ||
interfaceId == type(IERC165Handler).interfaceId ||
_supportsInterface(interfaceId) ||
safeInterfaces[Safe(payable(_manager()))][interfaceId];
}

// --- internal ---

/**
* A stub function to be overridden by inheriting contracts to add support for additional interfaces
* @param interfaceId The interface id to check support for
* @return True if the interface is supported
*/
function _supportsInterface(bytes4 interfaceId) internal view virtual returns (bool);
}
86 changes: 86 additions & 0 deletions contracts/handler/extensible/ExtensibleBase.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// SPDX-License-Identifier: LGPL-3.0-only
pragma solidity >=0.7.0 <0.9.0;

import {Safe} from "../../Safe.sol";
import {HandlerContext} from "../HandlerContext.sol";
import {MarshalLib} from "./MarshalLib.sol";

interface IFallbackMethod {
function handle(Safe safe, address sender, uint256 value, bytes calldata data) external returns (bytes memory result);
}

interface IStaticFallbackMethod {
function handle(Safe safe, address sender, uint256 value, bytes calldata data) external view returns (bytes memory result);
}

/**
* @title Base contract for Extensible Fallback Handlers
* @dev This contract provides the base for storage and modifiers for extensible fallback handlers
* @author mfw78 <[email protected]>
*/
abstract contract ExtensibleBase is HandlerContext {
// --- events ---
event AddedSafeMethod(Safe indexed safe, bytes4 selector, bytes32 method);
event ChangedSafeMethod(Safe indexed safe, bytes4 selector, bytes32 oldMethod, bytes32 newMethod);
event RemovedSafeMethod(Safe indexed safe, bytes4 selector);

// --- storage ---

// A mapping of Safe => selector => method
// The method is a bytes32 that is encoded as follows:
// - The first byte is 0x00 if the method is static and 0x01 if the method is not static
// - The last 20 bytes are the address of the handler contract
// The method is encoded / decoded using the MarshalLib
mapping(Safe => mapping(bytes4 => bytes32)) public safeMethods;

// --- modifiers ---
modifier onlySelf() {
// Use the `HandlerContext._msgSender()` to get the caller of the fallback function
// Use the `HandlerContext._manager()` to get the manager, which should be the Safe
// Require that the caller is the Safe itself
require(_msgSender() == _manager(), "only safe can call this method");
_;
}

// --- internal ---

function _setSafeMethod(Safe safe, bytes4 selector, bytes32 newMethod) internal {
(, address newHandler) = MarshalLib.decode(newMethod);
bytes32 oldMethod = safeMethods[safe][selector];
(, address oldHandler) = MarshalLib.decode(oldMethod);

if (address(newHandler) == address(0) && address(oldHandler) != address(0)) {
delete safeMethods[safe][selector];
emit RemovedSafeMethod(safe, selector);
} else {
safeMethods[safe][selector] = newMethod;
if (address(oldHandler) == address(0)) {
emit AddedSafeMethod(safe, selector, newMethod);
} else {
emit ChangedSafeMethod(safe, selector, oldMethod, newMethod);
}
}
}

/**
* Dry code to get the Safe and the original `msg.sender` from the FallbackManager
* @return safe The safe whose FallbackManager is making this call
* @return sender The original `msg.sender` (as received by the FallbackManager)
*/
function _getContext() internal view returns (Safe safe, address sender) {
safe = Safe(payable(_manager()));
sender = _msgSender();
}

/**
* Get the context and the method handler applicable to the current call
* @return safe The safe whose FallbackManager is making this call
* @return sender The original `msg.sender` (as received by the FallbackManager)
* @return isStatic Whether the method is static (`view`) or not
* @return handler the address of the handler contract
*/
function _getContextAndHandler() internal view returns (Safe safe, address sender, bool isStatic, address handler) {
(safe, sender) = _getContext();
(isStatic, handler) = MarshalLib.decode(safeMethods[safe][msg.sig]);
}
}
42 changes: 42 additions & 0 deletions contracts/handler/extensible/FallbackHandler.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// SPDX-License-Identifier: LGPL-3.0-only
pragma solidity >=0.7.0 <0.9.0;

import {Safe, IStaticFallbackMethod, IFallbackMethod, ExtensibleBase} from "./ExtensibleBase.sol";

interface IFallbackHandler {
function setSafeMethod(bytes4 selector, bytes32 newMethod) external;
}

/**
* @title FallbackHandler - A fully extensible fallback handler for Safes
* @dev This contract provides a fallback handler for Safes that can be extended with custom fallback handlers
* for specific methods.
* @author mfw78 <[email protected]>
*/
abstract contract FallbackHandler is ExtensibleBase, IFallbackHandler {
// --- setters ---

/**
* Setter for custom method handlers
* @param selector The `bytes4` selector of the method to set the handler for
* @param newMethod A contract that implements the `IFallbackMethod` or `IStaticFallbackMethod` interface
*/
function setSafeMethod(bytes4 selector, bytes32 newMethod) public override onlySelf {
_setSafeMethod(Safe(payable(_msgSender())), selector, newMethod);
}

// --- fallback ---

// solhint-disable-next-line
fallback(bytes calldata) external returns (bytes memory result) {
require(msg.data.length >= 24, "invalid method selector");
(Safe safe, address sender, bool isStatic, address handler) = _getContextAndHandler();
require(handler != address(0), "method handler not set");

if (isStatic) {
result = IStaticFallbackMethod(handler).handle(safe, sender, 0, msg.data[:msg.data.length - 20]);
} else {
result = IFallbackMethod(handler).handle(safe, sender, 0, msg.data[:msg.data.length - 20]);
}
}
}
60 changes: 60 additions & 0 deletions contracts/handler/extensible/MarshalLib.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// SPDX-License-Identifier: LGPL-3.0-only
pragma solidity >=0.7.0 <0.9.0;

library MarshalLib {
/**
* Encode a method handler into a `bytes32` value
* @dev The first byte of the `bytes32` value is set to 0x01 if the method is not static (`view`)
* @dev The last 20 bytes of the `bytes32` value are set to the address of the handler contract
* @param isStatic Whether the method is static (`view`) or not
* @param handler The address of the handler contract implementing the `IFallbackMethod` or `IStaticFallbackMethod` interface
*/
function encode(bool isStatic, address handler) internal pure returns (bytes32 data) {
data = bytes32(uint256(uint160(handler)) | (isStatic ? 0 : (1 << 248)));
}

/**
* Encode a method handler into a `bytes32` value with a selector
* @dev The first byte of the `bytes32` value is set to 0x01 if the method is not static (`view`)
* @dev The next 4 bytes of the `bytes32` value are set to the selector of the method
* @dev The last 20 bytes of the `bytes32` value are set to the address of the handler contract
* @param isStatic Whether the method is static (`view`) or not
* @param selector The selector of the method
* @param handler The address of the handler contract implementing the `IFallbackMethod` or `IStaticFallbackMethod` interface
*/
function encodeWithSelector(bool isStatic, bytes4 selector, address handler) internal pure returns (bytes32 data) {
data = bytes32(uint256(uint160(handler)) | (isStatic ? 0 : (1 << 248)) | (uint256(uint32(selector)) << 216));
}

/**
* Given a `bytes32` value, decode it into a method handler and return it
* @param data The packed data to decode
* @return isStatic Whether the method is static (`view`) or not
* @return handler The address of the handler contract implementing the `IFallbackMethod` or `IStaticFallbackMethod` interface
*/
function decode(bytes32 data) internal pure returns (bool isStatic, address handler) {
// solhint-disable-next-line no-inline-assembly
assembly {
// set isStatic to true if the left-most byte of the data is 0x00
isStatic := iszero(shr(248, data))
handler := shr(96, shl(96, data))
}
}

/**
* Given a `bytes32` value, decode it into a method handler and return it
* @param data The packed data to decode
* @return isStatic Whether the method is static (`view`) or not
* @return selector The selector of the method
* @return handler The address of the handler contract implementing the `IFallbackMethod` or `IStaticFallbackMethod` interface
*/
function decodeWithSelector(bytes32 data) internal pure returns (bool isStatic, bytes4 selector, address handler) {
// solhint-disable-next-line no-inline-assembly
assembly {
// set isStatic to true if the left-most byte of the data is 0x00
isStatic := iszero(shr(248, data))
handler := shr(96, shl(96, data))
selector := shl(168, shr(160, data))
}
}
}
Loading

0 comments on commit 76ea23d

Please sign in to comment.