From 0f3ee326009709a705eb560712bad070b52d8890 Mon Sep 17 00:00:00 2001 From: Akshay Date: Mon, 30 Oct 2023 12:07:05 +0100 Subject: [PATCH] Add SignatureValidatorManager contract (#109) * [#108] Create SignatureValidatorManager contract * [#108] SignatureValidatorManager to be used as a function handler * wip: tests * [#108] Implement signature validation flow * [#108] Add validator hooks * [#108] Add test for signature validation with hooks * [#108] Add tests for signature validator flow * [#108] Add tests for signature validator, minor fix in Registry * [#108] Rename ISafeProtocol712SignatureValidator to ISafeProtocolSignatureValidator, Update natspec doc * [#108] Define signature selector as const, reorg signature selector tests * [#108] Document layout of signatures * [#108] Validate messageHash * [#108] Add mocks and tests for signature validators with hooks * [#108] Add tests for default validation flow, minor test updates * [#108] Fix lint issues * [#108] Update natspec for events * [#108] Remove ISafeAccount interface * [#108] Add domain separator in default validation flow * [#108] Fix typo * [#108] Use number as type for constants * [#108] Use if-else in deciding validator routing * [#108] use hash in defaultValidator instead of hash of hash * [#108] Use Signature Validator Manager's domain separator and type hash for default valdiation * Update contracts/SignatureValidatorManager.sol Co-authored-by: Nicholas Rodrigues Lordello * [#108] Fix lint issue * [#108] Define preise error reason, fix tests * [#108] Create function that checks if contract supports interface in Registry * [#108] SignatureMalidatorManager handle(...) function is view --------- Co-authored-by: Nicholas Rodrigues Lordello --- contracts/SafeProtocolRegistry.sol | 61 +-- contracts/SignatureValidatorManager.sol | 231 +++++++++++ contracts/common/Constants.sol | 2 + contracts/interfaces/Accounts.sol | 2 + contracts/interfaces/Manager.sol | 13 + contracts/interfaces/Modules.sol | 37 ++ contracts/test/TestExecutor.sol | 4 + deploy/deploy_protocol.ts | 7 + src/utils/constants.ts | 20 +- test/SafeProtocolManager.spec.ts | 16 +- test/SafeProtocolRegistry.spec.ts | 63 ++- test/SignatureValidatorManager.spec.ts | 503 +++++++++++++++++++++++ test/utils/contracts.ts | 12 + test/utils/mockValidationHooksBuilder.ts | 38 ++ 14 files changed, 963 insertions(+), 46 deletions(-) create mode 100644 contracts/SignatureValidatorManager.sol create mode 100644 test/SignatureValidatorManager.spec.ts create mode 100644 test/utils/mockValidationHooksBuilder.ts diff --git a/contracts/SafeProtocolRegistry.sol b/contracts/SafeProtocolRegistry.sol index 54dba868..89dd11b9 100644 --- a/contracts/SafeProtocolRegistry.sol +++ b/contracts/SafeProtocolRegistry.sol @@ -4,8 +4,8 @@ import {ISafeProtocolRegistry} from "./interfaces/Registry.sol"; import {Ownable2Step} from "@openzeppelin/contracts/access/Ownable2Step.sol"; import {IERC165} from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; import {Enum} from "./common/Enum.sol"; -import {ISafeProtocolFunctionHandler, ISafeProtocolHooks, ISafeProtocolPlugin} from "./interfaces/Modules.sol"; -import {MODULE_TYPE_PLUGIN, MODULE_TYPE_HOOKS, MODULE_TYPE_FUNCTION_HANDLER} from "./common/Constants.sol"; +import {ISafeProtocolFunctionHandler, ISafeProtocolHooks, ISafeProtocolPlugin, ISafeProtocolSignatureValidator, ISafeProtocolSignatureValidatorHooks} from "./interfaces/Modules.sol"; +import {MODULE_TYPE_PLUGIN, MODULE_TYPE_HOOKS, MODULE_TYPE_FUNCTION_HANDLER, MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS, MODULE_TYPE_SIGNATURE_VALIDATOR} from "./common/Constants.sol"; contract SafeProtocolRegistry is ISafeProtocolRegistry, Ownable2Step { mapping(address => ModuleInfo) public listedModules; @@ -17,7 +17,8 @@ contract SafeProtocolRegistry is ISafeProtocolRegistry, Ownable2Step { } error CannotFlagModule(address module); - error CannotAddModule(address module, uint8 moduleTypes); + error ModuleAlreadyListed(address module); + error InvalidModuleType(address module, uint8 givenModuleType); error ModuleDoesNotSupportExpectedInterfaceId(address module, bytes4 expectedInterfaceId); event ModuleAdded(address indexed module); @@ -59,36 +60,48 @@ contract SafeProtocolRegistry is ISafeProtocolRegistry, Ownable2Step { ModuleInfo memory moduleInfo = listedModules[module]; // Check if module is already listed or if moduleTypes is greater than 8. - // Maximum allowed value of moduleTypes is 7. i.e. 2^0 (Plugin) + 2^1 (Function Handler) + 2^2 (Hooks) - if (moduleInfo.listedAt != 0 || moduleTypes > 7) { - revert CannotAddModule(module, moduleTypes); + if (moduleInfo.listedAt != 0) { + revert ModuleAlreadyListed(module); } - // Check if module supports expected interface - if ( - moduleTypes & MODULE_TYPE_HOOKS == MODULE_TYPE_HOOKS && !IERC165(module).supportsInterface(type(ISafeProtocolHooks).interfaceId) - ) { - revert ModuleDoesNotSupportExpectedInterfaceId(module, type(ISafeProtocolHooks).interfaceId); + // Maximum allowed value of moduleTypes is 31. i.e. 2^0 (Plugin) + 2^1 (Function Handler) + 2^2 (Hooks) + 2^3 (Signature Validator hooks) + 2^4 (Signature Validator) + if (moduleTypes > 31) { + revert InvalidModuleType(module, moduleTypes); } - if ( - moduleTypes & MODULE_TYPE_PLUGIN == MODULE_TYPE_PLUGIN && - !IERC165(module).supportsInterface(type(ISafeProtocolPlugin).interfaceId) - ) { - revert ModuleDoesNotSupportExpectedInterfaceId(module, type(ISafeProtocolPlugin).interfaceId); - } - - if ( - moduleTypes & MODULE_TYPE_FUNCTION_HANDLER == MODULE_TYPE_FUNCTION_HANDLER && - !IERC165(module).supportsInterface(type(ISafeProtocolFunctionHandler).interfaceId) - ) { - revert ModuleDoesNotSupportExpectedInterfaceId(module, type(ISafeProtocolFunctionHandler).interfaceId); - } + optionalCheckInterfaceSupport(module, moduleTypes, MODULE_TYPE_PLUGIN, type(ISafeProtocolPlugin).interfaceId); + optionalCheckInterfaceSupport(module, moduleTypes, MODULE_TYPE_FUNCTION_HANDLER, type(ISafeProtocolFunctionHandler).interfaceId); + optionalCheckInterfaceSupport(module, moduleTypes, MODULE_TYPE_HOOKS, type(ISafeProtocolHooks).interfaceId); + optionalCheckInterfaceSupport( + module, + moduleTypes, + MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS, + type(ISafeProtocolSignatureValidatorHooks).interfaceId + ); + optionalCheckInterfaceSupport( + module, + moduleTypes, + MODULE_TYPE_SIGNATURE_VALIDATOR, + type(ISafeProtocolSignatureValidator).interfaceId + ); listedModules[module] = ModuleInfo(uint64(block.timestamp), 0, moduleTypes); emit ModuleAdded(module); } + /** + * @notice This function checks if module supports expected interfaceId. This function will revert if module does not support expected interfaceId. + * @param module Address of the module + * @param moduleTypes uint8 representing the types of module + * @param moduleTypeToCheck uint8 representing the type of module to check + * @param interfaceId bytes4 representing the interfaceId to check + */ + function optionalCheckInterfaceSupport(address module, uint8 moduleTypes, uint8 moduleTypeToCheck, bytes4 interfaceId) internal view { + if (moduleTypes & moduleTypeToCheck == moduleTypeToCheck && !IERC165(module).supportsInterface(interfaceId)) { + revert ModuleDoesNotSupportExpectedInterfaceId(module, interfaceId); + } + } + /** * @notice Allows only owner to flad a module. Only previously added module can be flagged. * This function does not permit flagging a module twice. diff --git a/contracts/SignatureValidatorManager.sol b/contracts/SignatureValidatorManager.sol new file mode 100644 index 00000000..9f6cdcd7 --- /dev/null +++ b/contracts/SignatureValidatorManager.sol @@ -0,0 +1,231 @@ +// SPDX-License-Identifier: LGPL-3.0-only +pragma solidity ^0.8.18; +import {ISafeProtocolSignatureValidator} from "./interfaces/Modules.sol"; +import {RegistryManager} from "./base/RegistryManager.sol"; +import {ISafeProtocolFunctionHandler, ISafeProtocolSignatureValidatorHooks} from "./interfaces/Modules.sol"; +import {MODULE_TYPE_SIGNATURE_VALIDATOR, MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS} from "./common/Constants.sol"; +import {IAccount} from "./interfaces/Accounts.sol"; +import {ISafeProtocolSignatureValidatorManager} from "./interfaces/Manager.sol"; +import {IERC165} from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; + +/** + * @title SignatureValidatorManager + * @notice This contract facilitates signature validation. It maintains the signature validator(s) per account per domain or uses a default validation scheme based on the content of the data passed. + * This contract follows the Safe{Core} Protocol specification for signature validation. For more details on specification refer to https://github.com/safe-global/safe-core-protocol-specs. + * Implementaion of this contract is inspired by this pull request: https://github.com/rndlabs/safe-contracts/pull/1/files. + * Expected setup to use this contract for signature validation is as follows: + * - Account enables SafeProtocolManager as a fallback handler + * - Account sets SignatureValidatorManager as a function handler for a function selector e.g. 0x1626ba7e i.e. bytes4(keccak256("isValidSignature(bytes32,bytes)") + * @dev SignatureValidatorManager inherits RegistryManager leading to possible state drift of Registry address with the SafeProtocolManager contract. + * Do not set this contract as a fallback handler of a Safe account. Using this as a fallback handler would allow unauthorised setting of signature validator and signature validator hooks. + */ + +contract SignatureValidatorManager is RegistryManager, ISafeProtocolFunctionHandler, ISafeProtocolSignatureValidatorManager { + constructor(address _registry, address _initialOwner) RegistryManager(_registry, _initialOwner) {} + + // constants + // Signature selector bytes4(keccak256("Account712Signature(bytes32,bytes32,bytes)")); + bytes4 public constant SIGNATURE_VALIDATOR_SELECTOR = 0xb5c726cb; + + // keccak256("SafeMessage(bytes message)"); + bytes32 private constant ACCOUNT_MSG_TYPEHASH = 0x60b3cbf8b4a223d68d641b3b6ddf9a298e7f33710cf3d3a9d1146b5a6150fbca; + + // keccak256( + // "EIP712Domain(uint256 chainId,address verifyingContract)" + // ); + bytes32 private constant DOMAIN_SEPARATOR_TYPEHASH = 0x47e79534a245952e8b16893a336b85a3d9ea9fa8c573f3d803afb92a79469218; + + // Storage + /** + * @notice Mapping to account address => domain separator => signature validator contract + */ + mapping(address => mapping(bytes32 => address)) public signatureValidators; + + /** + * @notice Mapping to account address => signature validator hooks contract + */ + mapping(address => address) public signatureValidatorHooks; + + // Events + /** + * @notice Only one type of event is emitted for simplicity rather one for each individual case: remvoing, updating, adding new signature validator. + */ + event SignatureValidatorChanged(address indexed account, bytes32 indexed domainSeparator, address indexed signatureValidator); + + /** + * @notice Only one type of event is emitted for simplicity rather one for each individual case: remvoing, updating, adding new signature validator hooks. + */ + event SignatureValidatorHooksChanged(address indexed account, address indexed signatureValidatorHooks); + + // Errors + error SignatureValidatorNotSet(address account); + error InvalidMessageHash(bytes32 messageHash); + + /** + * @notice Sets the signature validator contract for an account + * @param signatureValidator Address of the signature validator contract + */ + function setSignatureValidator(bytes32 domainSeparator, address signatureValidator) external { + if (signatureValidator != address(0)) { + checkPermittedModule(signatureValidator, MODULE_TYPE_SIGNATURE_VALIDATOR); + + if (!ISafeProtocolSignatureValidator(signatureValidator).supportsInterface(type(ISafeProtocolSignatureValidator).interfaceId)) + revert ContractDoesNotImplementValidInterfaceId(signatureValidator); + } + signatureValidators[msg.sender][domainSeparator] = signatureValidator; + + emit SignatureValidatorChanged(msg.sender, domainSeparator, signatureValidator); + } + + /** + * @notice Sets the signature validator hooks for an account + * @param signatureValidatorHooksAddress Address of the signature validator hooks contract + */ + function setSignatureValidatorHooks(address signatureValidatorHooksAddress) external override { + if (signatureValidatorHooksAddress != address(0)) { + checkPermittedModule(signatureValidatorHooksAddress, MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS); + + if ( + !ISafeProtocolSignatureValidatorHooks(signatureValidatorHooksAddress).supportsInterface( + type(ISafeProtocolSignatureValidatorHooks).interfaceId + ) + ) revert ContractDoesNotImplementValidInterfaceId(signatureValidatorHooksAddress); + } + signatureValidatorHooks[msg.sender] = signatureValidatorHooksAddress; + + emit SignatureValidatorHooksChanged(msg.sender, signatureValidatorHooksAddress); + } + + /** + * @notice A view function that the Manager will call when an account has enabled this contract as a function handler in the Manager + * @param account Address of the account whose signature validator is to be used + * @param sender Address requesting signature validation + * @param data Calldata containing the 4 bytes function selector, 32 bytes message hash and payload. + * Layout of data: + * 0x00 to 0x04 - 4 bytes function selector when this contract is set as a function handler in the SafeProtocolManager i.e. 0x1626ba7e + * 0x04 to 0x24 - 32 bytes hash of the message used for signing + * 0x24 to end - bytes containing signatures or signatureData either one of the below: + * If first 4 bytes of signatureData are 0xb5c726cb i.e. bytes4(keccak256("Account712Signature(bytes32,bytes32,bytes)")); then it will be interpreted as follows: + * payload = abi.encodeWithSelector(0xb5c726cb, abi.encode(domainSeparator, structHash, signatures) + * Layout of `data` parameter in this case: + * 0x00 to 0x04 - 4 bytes function selector when this contract is set as a function handler in the SafeProtocolManager i.e. 0x1626ba7e + * 0x04 to 0x24 - 32 bytes hash of the signed message + * 0x24 to 0x44 - 32 bytes offset to the start of `bytes` parameter + * 0x44 to 0x64 - 32 bytes length of `bytes` parameter + * 0x64 to 0x68 - 4 bytes of Signature selector + * 0x68 to 0x88 - 32 bytes domain separator + * 0x88 to 0xa8 - 32 bytes struct hash + * 0xa8 to end - contains offset, length of bytes, and actual bytes containing signatures + * Else: + * bytes containing signature data + * default validation flow will be used which will depend on the account implementation + * + */ + function handle( + address account, + address sender, + uint256 /* value */, + bytes calldata data + ) external view override returns (bytes memory) { + // Skip first 4 bytes of data as it contains function selector + (bytes32 messageHash, bytes memory signatureData) = abi.decode(data[0x4:], (bytes32, bytes)); + + address signatureValidatorHooksAddress = signatureValidatorHooks[account]; + bytes memory prevalidationData; + bytes memory returnData; + + if (signatureValidatorHooksAddress != address(0)) { + checkPermittedModule(signatureValidatorHooksAddress, MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS); + prevalidationData = ISafeProtocolSignatureValidatorHooks(signatureValidatorHooksAddress).preValidationHook( + account, + sender, + data + ); + } + + if (bytes4(data[0x64:0x68]) == SIGNATURE_VALIDATOR_SELECTOR) { + returnData = abi.encode(validateWithSignatureValdiator(account, sender, messageHash, data[0x68:])); + } else { + returnData = defaultValidator(account, messageHash, signatureData); + } + + if (signatureValidatorHooksAddress != address(0)) { + ISafeProtocolSignatureValidatorHooks(signatureValidatorHooksAddress).postValidationHook(account, prevalidationData); + } + return returnData; + } + + /** + * @notice A view function for default signature validation flow. + * @param account Address of the account whose signature is to be validated. Account should support function `checkSignatures(bytes32, bytes, bytes)` + * @param hash bytes32 hash of the data that is signed + * @param signatures Arbitrary length bytes array containing the signatures + */ + function defaultValidator(address account, bytes32 hash, bytes memory signatures) internal view returns (bytes memory) { + bytes memory messageData = abi.encodePacked( + bytes1(0x19), + bytes1(0x01), + keccak256(abi.encode(DOMAIN_SEPARATOR_TYPEHASH, block.chainid, account)), + keccak256(abi.encode(ACCOUNT_MSG_TYPEHASH, keccak256(abi.encode(hash)))) + ); + bytes32 messageHash = keccak256(messageData); + IAccount(account).checkSignatures(messageHash, messageData, signatures); + // bytes4(keccak256("isValidSignature(bytes32,bytes)") + return abi.encode(0x1626ba7e); + } + + /** + * + * @param account Address of the account whose signature is to be validated + * @param sender Address of the entitty that requested for signature validation + * @param messageHash Hash of the message that is signed + * @param data Arbitrary length bytes array containing the domain separator, struct hash and signatures + */ + function validateWithSignatureValdiator( + address account, + address sender, + bytes32 messageHash, + bytes calldata data + ) internal view returns (bytes4) { + (bytes32 domainSeparator, bytes32 structHash, bytes memory signatures) = abi.decode(data, (bytes32, bytes32, bytes)); + + if (keccak256(abi.encodePacked(bytes1(0x19), bytes1(0x01), domainSeparator, structHash)) != messageHash) { + revert InvalidMessageHash(messageHash); + } + + address signatureValidator = signatureValidators[account][domainSeparator]; + if (signatureValidator == address(0)) { + revert SignatureValidatorNotSet(account); + } + + checkPermittedModule(signatureValidator, MODULE_TYPE_SIGNATURE_VALIDATOR); + + return + ISafeProtocolSignatureValidator(signatureValidator).isValidSignature( + account, + sender, + messageHash, + domainSeparator, + structHash, + signatures + ); + } + + /** + * @notice A function that returns module information. + * @return providerType uint256 Type of metadata provider + * @return location Arbitrary length bytes data containing the location of the metadata provider + */ + function metadataProvider() external view override returns (uint256 providerType, bytes memory location) {} + + /** + * @param interfaceId bytes4 interface id to be checked + * @return true if interface is supported + */ + function supportsInterface(bytes4 interfaceId) external view override returns (bool) { + return + interfaceId == type(IERC165).interfaceId || + interfaceId == type(ISafeProtocolSignatureValidatorManager).interfaceId || + interfaceId == type(ISafeProtocolFunctionHandler).interfaceId; + } +} diff --git a/contracts/common/Constants.sol b/contracts/common/Constants.sol index f64ff3ee..36d10e6f 100644 --- a/contracts/common/Constants.sol +++ b/contracts/common/Constants.sol @@ -11,3 +11,5 @@ uint8 constant PLUGIN_PERMISSION_EXECUTE_DELEGATECALL = 4; uint8 constant MODULE_TYPE_PLUGIN = 1; uint8 constant MODULE_TYPE_FUNCTION_HANDLER = 2; uint8 constant MODULE_TYPE_HOOKS = 4; +uint8 constant MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS = 8; +uint8 constant MODULE_TYPE_SIGNATURE_VALIDATOR = 16; diff --git a/contracts/interfaces/Accounts.sol b/contracts/interfaces/Accounts.sol index a08ebf5a..475938f8 100644 --- a/contracts/interfaces/Accounts.sol +++ b/contracts/interfaces/Accounts.sol @@ -18,4 +18,6 @@ interface IAccount { bytes memory data, uint8 operation ) external returns (bool success, bytes memory returnData); + + function checkSignatures(bytes32 dataHash, bytes memory data, bytes memory signatures) external view; } diff --git a/contracts/interfaces/Manager.sol b/contracts/interfaces/Manager.sol index a9f5425c..90ad13fc 100644 --- a/contracts/interfaces/Manager.sol +++ b/contracts/interfaces/Manager.sol @@ -30,3 +30,16 @@ interface ISafeProtocolManager { */ function executeRootAccess(address account, SafeRootAccess calldata rootAccess) external returns (bytes memory data); } + +interface ISafeProtocolSignatureValidatorManager { + /** + * @param domainSeparator bytes32 containing the domain for which Signature Validator contract should be used + * @param signatureValidatorContract Address of the Signature Validator Contract implementing ISafeProtocolSignatureValidator interface + */ + function setSignatureValidator(bytes32 domainSeparator, address signatureValidatorContract) external; + + /** + * @param signatureValidatorHooksContract Address of the contract to be used as Hooks for Signature Validator implementing ISignatureValidatorHook interface + */ + function setSignatureValidatorHooks(address signatureValidatorHooksContract) external; +} diff --git a/contracts/interfaces/Modules.sol b/contracts/interfaces/Modules.sol index 7ccc7d1d..f18640d8 100644 --- a/contracts/interfaces/Modules.sol +++ b/contracts/interfaces/Modules.sol @@ -138,3 +138,40 @@ interface ISafeProtocolPlugin is IERC165 { */ function requiresPermissions() external view returns (uint8 permissions); } + +interface ISafeProtocolSignatureValidator is IERC165 { + /** + * @param account The account that has delegated the signature verification + * @param sender The address that originally called the Safe's `isValidSignature` method + * @param structHash The EIP-712 hash whose signature will be verified + * @param domainSeparator The EIP-712 domainSeparator + * @param structHash The EIP-712 structHash + * @param payload An arbitrary payload that can be used to pass additional data to the validator + * @return magic The magic value that should be returned if the signature is valid (0x1626ba7e) + */ + function isValidSignature( + address account, + address sender, + bytes32 messageHash, + bytes32 domainSeparator, + bytes32 structHash, + bytes calldata payload + ) external view returns (bytes4 magic); +} + +interface ISafeProtocolSignatureValidatorHooks is IERC165 { + /** + * @param account Address of the account for which signature is being validated + * @param validator Address of the validator contract to be used for signature validation. This address will be account address in case of default signature validation flow is used. + * @param payload The payload provided for the validation + * @return result bytes containing the result + */ + function preValidationHook(address account, address validator, bytes calldata payload) external view returns (bytes memory result); + + /** + * @param account Address of the account for which signature is being validated + * @param preValidationData Data returned by preValidationHook + * @return result bytes containing the result + */ + function postValidationHook(address account, bytes calldata preValidationData) external view returns (bytes memory result); +} diff --git a/contracts/test/TestExecutor.sol b/contracts/test/TestExecutor.sol index e5b83392..0f21b3cc 100644 --- a/contracts/test/TestExecutor.sol +++ b/contracts/test/TestExecutor.sol @@ -127,5 +127,9 @@ contract TestExecutor is IAccount { } } + function checkSignatures(bytes32 dataHash, bytes calldata data, bytes calldata signatures) external view { + // An empty function used for testing signature validator flow + } + receive() external payable {} } diff --git a/deploy/deploy_protocol.ts b/deploy/deploy_protocol.ts index 39e1187d..8fe0b1fd 100644 --- a/deploy/deploy_protocol.ts +++ b/deploy/deploy_protocol.ts @@ -18,6 +18,13 @@ const deploy: DeployFunction = async function (hre: HardhatRuntimeEnvironment) { log: true, deterministicDeployment: true, }); + + await deploy("SignatureValidatorManager", { + from: deployer, + args: [registry.address, owner], + log: true, + deterministicDeployment: true, + }); }; deploy.tags = ["protocol"]; diff --git a/src/utils/constants.ts b/src/utils/constants.ts index cec7b653..0c4cfc52 100644 --- a/src/utils/constants.ts +++ b/src/utils/constants.ts @@ -1,9 +1,15 @@ -export const PLUGIN_PERMISSION_NONE = 0n; -export const PLUGIN_PERMISSION_EXECUTE_CALL = 1n; -export const PLUGIN_PERMISSION_CALL_TO_SELF = 2n; -export const PLUGIN_PERMISSION_DELEGATE_CALL = 4n; +export const PLUGIN_PERMISSION_NONE: number = 0; +export const PLUGIN_PERMISSION_EXECUTE_CALL: number = 1; +export const PLUGIN_PERMISSION_CALL_TO_SELF: number = 2; +export const PLUGIN_PERMISSION_DELEGATE_CALL: number = 4; // Module types -export const MODULE_TYPE_PLUGIN = 1n; -export const MODULE_TYPE_FUNCTION_HANDLER = 2n; -export const MODULE_TYPE_HOOKS = 4n; +export const MODULE_TYPE_PLUGIN: number = 1; +export const MODULE_TYPE_FUNCTION_HANDLER: number = 2; +export const MODULE_TYPE_HOOKS: number = 4; +export const MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS: number = 8; +export const MODULE_TYPE_SIGNATURE_VALIDATOR: number = 16; + +// solidity: bytes4(keccak256("Account712Signature(bytes32,bytes32,bytes)")); +// javascript: hre.ethers.keccak256(toUtf8Bytes("Account712Signature(bytes32,bytes32,bytes)")).slice(0, 10); +export const SIGNATURE_VALIDATOR_SELECTOR = "0xb5c726cb"; diff --git a/test/SafeProtocolManager.spec.ts b/test/SafeProtocolManager.spec.ts index 75f2d4ff..47475f7d 100644 --- a/test/SafeProtocolManager.spec.ts +++ b/test/SafeProtocolManager.spec.ts @@ -182,7 +182,7 @@ describe("SafeProtocolManager", async () => { // As SafeProtocolManager is a fallback handler on a contract, call to enablePlugin(...) function will be // forwarded for SafeProtocolManager. Direct calls to SafeProtocolManager to enable plugin are intentionally blocked. await account.exec(account.target, 0, data); - expect(await safeProtocolManager.getPluginInfo(account.target, pluginAddress)).to.eql([ + expect(await safeProtocolManager.getPluginInfo(account.target, pluginAddress)).to.deep.equal([ PLUGIN_PERMISSION_EXECUTE_CALL, SENTINEL_MODULES, ]); @@ -202,7 +202,7 @@ describe("SafeProtocolManager", async () => { safeProtocolManager, "PluginPermissionsMismatch", ); - expect(await safeProtocolManager.getPluginInfo(account.target, pluginAddress)).to.eql([ + expect(await safeProtocolManager.getPluginInfo(account.target, pluginAddress)).to.deep.equal([ PLUGIN_PERMISSION_NONE, ZeroAddress, ]); @@ -270,14 +270,14 @@ describe("SafeProtocolManager", async () => { PLUGIN_PERMISSION_EXECUTE_CALL, ]); await account.exec(accountAddress, 0, data); - expect(await safeProtocolManager.getPluginInfo(accountAddress, pluginAddress)).to.eql([ + expect(await safeProtocolManager.getPluginInfo(accountAddress, pluginAddress)).to.deep.equal([ PLUGIN_PERMISSION_EXECUTE_CALL, SENTINEL_MODULES, ]); const data2 = safeProtocolManager.interface.encodeFunctionData("disablePlugin", [SENTINEL_MODULES, pluginAddress]); await account.exec(accountAddress, 0, data2); - expect(await safeProtocolManager.getPluginInfo(accountAddress, pluginAddress)).to.eql([ + expect(await safeProtocolManager.getPluginInfo(accountAddress, pluginAddress)).to.deep.equal([ PLUGIN_PERMISSION_NONE, ZeroAddress, ]); @@ -294,7 +294,7 @@ describe("SafeProtocolManager", async () => { PLUGIN_PERMISSION_EXECUTE_CALL, ]); await account.exec(account.target, 0, data); - expect(await safeProtocolManager.getPluginInfo(account.target, pluginAddress)).to.eql([ + expect(await safeProtocolManager.getPluginInfo(account.target, pluginAddress)).to.deep.equal([ PLUGIN_PERMISSION_EXECUTE_CALL, SENTINEL_MODULES, ]); @@ -453,14 +453,14 @@ describe("SafeProtocolManager", async () => { PLUGIN_PERMISSION_EXECUTE_CALL, ]); await account.exec(account.target, 0, data); - expect(await safeProtocolManager.getPluginInfo(accountAddress, pluginAddress)).to.eql([ + expect(await safeProtocolManager.getPluginInfo(accountAddress, pluginAddress)).to.deep.equal([ PLUGIN_PERMISSION_EXECUTE_CALL, SENTINEL_MODULES, ]); const data2 = safeProtocolManager.interface.encodeFunctionData("disablePlugin", [SENTINEL_MODULES, pluginAddress]); await account.exec(account.target, 0, data2); - expect(await safeProtocolManager.getPluginInfo(accountAddress, pluginAddress)).to.eql([ + expect(await safeProtocolManager.getPluginInfo(accountAddress, pluginAddress)).to.deep.equal([ PLUGIN_PERMISSION_NONE, ZeroAddress, ]); @@ -505,7 +505,7 @@ describe("SafeProtocolManager", async () => { // Disable plugin 2 const data2 = safeProtocolManager.interface.encodeFunctionData("disablePlugin", [plugin3.target, plugin2.target]); await account.exec(account.target, 0, data2); - expect(await safeProtocolManager.getPluginInfo(account.target, plugin2.target)).to.eql([ + expect(await safeProtocolManager.getPluginInfo(account.target, plugin2.target)).to.deep.equal([ PLUGIN_PERMISSION_NONE, ZeroAddress, ]); diff --git a/test/SafeProtocolRegistry.spec.ts b/test/SafeProtocolRegistry.spec.ts index 12f9b514..32caa846 100644 --- a/test/SafeProtocolRegistry.spec.ts +++ b/test/SafeProtocolRegistry.spec.ts @@ -2,7 +2,13 @@ import hre, { ethers, deployments } from "hardhat"; import { expect } from "chai"; import { AddressZero } from "@ethersproject/constants"; import { SignerWithAddress } from "@nomicfoundation/hardhat-ethers/signers"; -import { MODULE_TYPE_PLUGIN, MODULE_TYPE_HOOKS, MODULE_TYPE_FUNCTION_HANDLER } from "../src/utils/constants"; +import { + MODULE_TYPE_PLUGIN, + MODULE_TYPE_HOOKS, + MODULE_TYPE_FUNCTION_HANDLER, + MODULE_TYPE_SIGNATURE_VALIDATOR, + MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS, +} from "../src/utils/constants"; import { getHooksWithPassingChecks, getHooksWithFailingCallToSupportsInterfaceMethod } from "./utils/mockHooksBuilder"; import { getPluginWithFailingCallToSupportsInterfaceMethod } from "./utils/mockPluginBuilder"; import { getFunctionHandlerWithFailingCallToSupportsInterfaceMethod } from "./utils/mockFunctionHandlerBuilder"; @@ -18,7 +24,7 @@ describe("SafeProtocolRegistry", async () => { }); // A helper function to convert a number to a bytes32 value - const numberToBytes32 = (value: bigint) => hre.ethers.zeroPadValue(hre.ethers.toBeHex(value), 32); + const numberToBytes32 = (value: number) => hre.ethers.zeroPadValue(hre.ethers.toBeHex(value), 32); it("Should allow adding a module only once", async () => { const { safeProtocolRegistry } = await setupTests(); @@ -27,7 +33,7 @@ describe("SafeProtocolRegistry", async () => { await safeProtocolRegistry.connect(owner).addModule(mockHookAddress, MODULE_TYPE_HOOKS); await expect(safeProtocolRegistry.connect(owner).addModule(mockHookAddress, MODULE_TYPE_HOOKS)).to.be.revertedWithCustomError( safeProtocolRegistry, - "CannotAddModule", + "ModuleAlreadyListed", ); }); @@ -39,11 +45,20 @@ describe("SafeProtocolRegistry", async () => { await safeProtocolRegistry .connect(owner) - .addModule(mockModule, MODULE_TYPE_PLUGIN + MODULE_TYPE_FUNCTION_HANDLER + MODULE_TYPE_HOOKS); + .addModule( + mockModule, + MODULE_TYPE_PLUGIN + + MODULE_TYPE_FUNCTION_HANDLER + + MODULE_TYPE_HOOKS + + MODULE_TYPE_SIGNATURE_VALIDATOR + + MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS, + ); + const [listedAt, flaggedAt] = await safeProtocolRegistry.check.staticCall( mockModule.target, numberToBytes32(MODULE_TYPE_FUNCTION_HANDLER), ); + expect(listedAt).to.be.greaterThan(0); expect(flaggedAt).to.be.equal(0); @@ -54,15 +69,29 @@ describe("SafeProtocolRegistry", async () => { const [listedAt3, flaggedAt3] = await safeProtocolRegistry.check.staticCall(mockModule.target, numberToBytes32(MODULE_TYPE_HOOKS)); expect(listedAt3).to.be.greaterThan(0); expect(flaggedAt3).to.be.equal(0); + + const [listedAt4, flaggedAt4] = await safeProtocolRegistry.check.staticCall( + mockModule.target, + numberToBytes32(MODULE_TYPE_SIGNATURE_VALIDATOR), + ); + expect(listedAt4).to.be.greaterThan(0); + expect(flaggedAt4).to.be.equal(0); + + const [listedAt5, flaggedAt5] = await safeProtocolRegistry.check.staticCall( + mockModule.target, + numberToBytes32(MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS), + ); + expect(listedAt5).to.be.greaterThan(0); + expect(flaggedAt5).to.be.equal(0); }); it("Should not allow adding a module with invalid moduleTypes", async () => { const { safeProtocolRegistry } = await setupTests(); const mockHookAddress = (await getHooksWithPassingChecks()).target; - await expect(safeProtocolRegistry.connect(owner).addModule(mockHookAddress, 8)) - .to.be.revertedWithCustomError(safeProtocolRegistry, "CannotAddModule") - .withArgs(mockHookAddress, 8); + await expect(safeProtocolRegistry.connect(owner).addModule(mockHookAddress, 32)) + .to.be.revertedWithCustomError(safeProtocolRegistry, "InvalidModuleType") + .withArgs(mockHookAddress, 32); }); it("Should not allow non-owner to add a module", async () => { @@ -166,4 +195,24 @@ describe("SafeProtocolRegistry", async () => { .to.be.revertedWithCustomError(safeProtocolRegistry, "ModuleDoesNotSupportExpectedInterfaceId") .withArgs(mockFunctionHandlerAddress, "0xf601ad15"); }); + + it("Should revert when signature validator hooks not supporting expected interfaceId", async () => { + const { safeProtocolRegistry } = await setupTests(); + const mockContract = await (await hre.ethers.getContractFactory("MockContract")).deploy(); + await mockContract.givenMethodReturnBool("0x01ffc9a7", false); + + await expect(safeProtocolRegistry.connect(owner).addModule(mockContract.target, MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS)) + .to.be.revertedWithCustomError(safeProtocolRegistry, "ModuleDoesNotSupportExpectedInterfaceId") + .withArgs(mockContract.target, "0xd340d5af"); + }); + + it("Should revert when signature validator not supporting expected interfaceId", async () => { + const { safeProtocolRegistry } = await setupTests(); + const mockContract = await (await hre.ethers.getContractFactory("MockContract")).deploy(); + await mockContract.givenMethodReturnBool("0x01ffc9a7", false); + + await expect(safeProtocolRegistry.connect(owner).addModule(mockContract.target, MODULE_TYPE_SIGNATURE_VALIDATOR)) + .to.be.revertedWithCustomError(safeProtocolRegistry, "ModuleDoesNotSupportExpectedInterfaceId") + .withArgs(mockContract.target, "0x38c8d4e6"); + }); }); diff --git a/test/SignatureValidatorManager.spec.ts b/test/SignatureValidatorManager.spec.ts new file mode 100644 index 00000000..d6ccc62b --- /dev/null +++ b/test/SignatureValidatorManager.spec.ts @@ -0,0 +1,503 @@ +import { SignerWithAddress } from "@nomicfoundation/hardhat-ethers/signers"; +import hre, { deployments } from "hardhat"; +import { getRegistry, getSafeProtocolManager, getSignatureValidatorManager } from "./utils/contracts"; +import { MaxUint256, ZeroAddress } from "ethers"; +import { + MODULE_TYPE_FUNCTION_HANDLER, + MODULE_TYPE_SIGNATURE_VALIDATOR, + MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS, +} from "../src/utils/constants"; +import { expect } from "chai"; +import { + getMockSignatureValidationHooks, + getMockSignatureValidationHooksWithFailingPostValidationHook, + getMockSignatureValidationHooksWithFailingPreValidationHook, +} from "./utils/mockValidationHooksBuilder"; +import { SIGNATURE_VALIDATOR_SELECTOR } from "../src/utils/constants"; + +describe("SignatureValidatorManager", () => { + let deployer: SignerWithAddress, owner: SignerWithAddress; + + const isValidSignatureInterface = new hre.ethers.Interface(["function isValidSignature(bytes32,bytes) public view returns (bytes4)"]); + + before(async () => { + [deployer, owner] = await hre.ethers.getSigners(); + }); + + const setupTests = deployments.createFixture(async ({ deployments }) => { + await deployments.fixture(); + + const safeProtocolSignatureValidatorManager = await getSignatureValidatorManager(); + const safeProtocolManager = await getSafeProtocolManager(); + + const safeProtocolRegistry = await getRegistry(); + await safeProtocolRegistry.connect(owner).addModule(safeProtocolSignatureValidatorManager.target, MODULE_TYPE_FUNCTION_HANDLER); + + const account = await hre.ethers.deployContract("TestExecutor", [safeProtocolManager.target], { signer: deployer }); + + return { account, safeProtocolSignatureValidatorManager, safeProtocolManager, safeProtocolRegistry }; + }); + + it("should revert when enabling a signature validator not implementing ISafeProtocolSignatureValidator interface", async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolRegistry } = await setupTests(); + + // set up mock contract as a signature validator + const mockContract = await hre.ethers.deployContract("MockContract", { signer: deployer }); + await mockContract.givenMethodReturnBool("0x01ffc9a7", true); + + await safeProtocolRegistry.connect(owner).addModule(mockContract.target, MODULE_TYPE_SIGNATURE_VALIDATOR); + + await mockContract.givenMethodReturnBool("0x01ffc9a7", false); + + const domainSeparator = hre.ethers.randomBytes(32); + + const dataSetValidator = safeProtocolSignatureValidatorManager.interface.encodeFunctionData("setSignatureValidator", [ + domainSeparator, + mockContract.target, + ]); + + await expect(account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidator, MaxUint256)) + .to.be.revertedWithCustomError(safeProtocolSignatureValidatorManager, "ContractDoesNotImplementValidInterfaceId") + .withArgs(mockContract.target); + }); + + it("should allow to remove signature validator", async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolRegistry } = await setupTests(); + + // set up mock contract as a signature validator + const mockContract = await hre.ethers.deployContract("MockContract", { signer: deployer }); + await mockContract.givenMethodReturnBool("0x01ffc9a7", true); + + await safeProtocolRegistry.connect(owner).addModule(mockContract.target, MODULE_TYPE_SIGNATURE_VALIDATOR); + + const domainSeparator = hre.ethers.randomBytes(32); + + const dataSetValidator = safeProtocolSignatureValidatorManager.interface.encodeFunctionData("setSignatureValidator", [ + domainSeparator, + mockContract.target, + ]); + + await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidator, MaxUint256); + + expect(await safeProtocolSignatureValidatorManager.signatureValidators(account.target, domainSeparator)).to.be.equal( + mockContract.target, + ); + + const dataResetValidator = safeProtocolSignatureValidatorManager.interface.encodeFunctionData("setSignatureValidator", [ + domainSeparator, + hre.ethers.ZeroAddress, + ]); + + await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataResetValidator, MaxUint256); + expect(await safeProtocolSignatureValidatorManager.signatureValidators(account.target, domainSeparator)).to.be.equal(ZeroAddress); + }); + + it("should revert when enabling a signature validator hooks not implementing ISafeProtocolSignatureValidatorHooks interface", async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolRegistry } = await setupTests(); + + // set up mock contract as a signature validator + const mockContract = await hre.ethers.deployContract("MockContract", { signer: deployer }); + await mockContract.givenMethodReturnBool("0x01ffc9a7", true); + + await safeProtocolRegistry.connect(owner).addModule(mockContract.target, MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS); + + await mockContract.givenMethodReturnBool("0x01ffc9a7", false); + + const dataSetValidatorHooks = safeProtocolSignatureValidatorManager.interface.encodeFunctionData("setSignatureValidatorHooks", [ + mockContract.target, + ]); + + await expect(account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidatorHooks, MaxUint256)) + .to.be.revertedWithCustomError(safeProtocolSignatureValidatorManager, "ContractDoesNotImplementValidInterfaceId") + .withArgs(mockContract.target); + }); + + it("should allow to remove signature validator hooks", async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolRegistry } = await setupTests(); + + // set up mock contract as a signature validator + const mockContract = await hre.ethers.deployContract("MockContract", { signer: deployer }); + await mockContract.givenMethodReturnBool("0x01ffc9a7", true); + + await safeProtocolRegistry.connect(owner).addModule(mockContract.target, MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS); + + const dataSetValidatorHooks = safeProtocolSignatureValidatorManager.interface.encodeFunctionData("setSignatureValidatorHooks", [ + mockContract.target, + ]); + + await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidatorHooks, MaxUint256); + + expect(await safeProtocolSignatureValidatorManager.signatureValidatorHooks(account.target)).to.be.equal(mockContract.target); + + const dataResetValidator = safeProtocolSignatureValidatorManager.interface.encodeFunctionData("setSignatureValidatorHooks", [ + hre.ethers.ZeroAddress, + ]); + + await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataResetValidator, MaxUint256); + expect(await safeProtocolSignatureValidatorManager.signatureValidatorHooks(account.target)).to.be.equal(ZeroAddress); + }); + + describe("signature validation per domain separator", async () => { + const createPayloadWithSelector = (domainSeparator: Uint8Array, structHash: Uint8Array, signatures: Uint8Array) => { + const encodedData = new hre.ethers.AbiCoder().encode( + ["bytes32", "bytes32", "bytes"], + [domainSeparator, structHash, signatures], + ); + + const encodeDataWithSelector = hre.ethers.solidityPacked(["bytes4", "bytes"], [SIGNATURE_VALIDATOR_SELECTOR, encodedData]); + + const messageHash = hre.ethers.keccak256( + hre.ethers.solidityPacked(["bytes1", "bytes1", "bytes32", "bytes32"], ["0x19", "0x01", domainSeparator, structHash]), + ); + + const data = isValidSignatureInterface.encodeFunctionData("isValidSignature", [messageHash, encodeDataWithSelector]); + + return data; + }; + + it("Should revert if signature validator is not registered", async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolManager } = await setupTests(); + + await account.setFallbackHandler(safeProtocolManager.target); + + const setFunctionHandlerData = safeProtocolManager.interface.encodeFunctionData("setFunctionHandler", [ + "0x1626ba7e", + safeProtocolSignatureValidatorManager.target, + ]); + await account.executeCallViaMock(account.target, 0, setFunctionHandlerData, MaxUint256); + + const data = createPayloadWithSelector(hre.ethers.randomBytes(32), hre.ethers.randomBytes(32), hre.ethers.randomBytes(64)); + + await expect(account.executeCallViaMock(account.target, 0, data, MaxUint256)) + .to.be.revertedWithCustomError(safeProtocolSignatureValidatorManager, "SignatureValidatorNotSet") + .withArgs(account.target); + }); + + it("Should call signature validator without validation hooks", async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolManager, safeProtocolRegistry } = await setupTests(); + + // 1. Set fallback handler + await account.setFallbackHandler(safeProtocolManager.target); + + // 2. Set function handler + const setFunctionHandlerData = safeProtocolManager.interface.encodeFunctionData("setFunctionHandler", [ + "0x1626ba7e", + safeProtocolSignatureValidatorManager.target, + ]); + + await account.executeCallViaMock(account.target, 0, setFunctionHandlerData, MaxUint256); + + // set up mock contract as a signature validator + const mockContract = await hre.ethers.deployContract("MockContract", { signer: deployer }); + // 0x38c8d4e6 => isValidSignature(address,address,bytes32,bytes32,bytes32,bytes) + const signatureValidatorReturnValue = new hre.ethers.AbiCoder().encode(["bytes4"], ["0x12345678"]); + await mockContract.givenMethodReturn("0x38c8d4e6", signatureValidatorReturnValue); + + await mockContract.givenMethodReturnBool("0x01ffc9a7", true); + + await safeProtocolRegistry.connect(owner).addModule(mockContract.target, MODULE_TYPE_SIGNATURE_VALIDATOR); + + const domainSeparator = hre.ethers.randomBytes(32); + + // 3. Set validator for domain separator + const dataSetValidatorManager = safeProtocolSignatureValidatorManager.interface.encodeFunctionData("setSignatureValidator", [ + domainSeparator, + mockContract.target, + ]); + + await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidatorManager, MaxUint256); + + const data = createPayloadWithSelector(domainSeparator, hre.ethers.randomBytes(32), hre.ethers.randomBytes(64)); + + expect(await account.executeCallViaMock.staticCall(account.target, 0, data, MaxUint256)).to.be.deep.equal([ + true, + signatureValidatorReturnValue, + ]); + }); + + it("Should revert if invalid message hash is passed", async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolManager, safeProtocolRegistry } = await setupTests(); + + // 1. Set fallback handler + await account.setFallbackHandler(safeProtocolManager.target); + + // 2. Set function handler + const setFunctionHandlerData = safeProtocolManager.interface.encodeFunctionData("setFunctionHandler", [ + "0x1626ba7e", + safeProtocolSignatureValidatorManager.target, + ]); + + await account.executeCallViaMock(account.target, 0, setFunctionHandlerData, MaxUint256); + + // set up mock contract as a signature validator + const mockContract = await hre.ethers.deployContract("MockContract", { signer: deployer }); + // 0x38c8d4e6 => isValidSignature(address,address,bytes32,bytes32,bytes32,bytes) + const signatureValidatorReturnValue = new hre.ethers.AbiCoder().encode(["bytes4"], ["0x12345678"]); + await mockContract.givenMethodReturn("0x38c8d4e6", signatureValidatorReturnValue); + + await mockContract.givenMethodReturnBool("0x01ffc9a7", true); + + await safeProtocolRegistry.connect(owner).addModule(mockContract.target, MODULE_TYPE_SIGNATURE_VALIDATOR); + + const domainSeparator = hre.ethers.randomBytes(32); + + // 3. Set validator for domain separator + const dataSetValidatorManager = safeProtocolSignatureValidatorManager.interface.encodeFunctionData("setSignatureValidator", [ + domainSeparator, + mockContract.target, + ]); + + await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidatorManager, MaxUint256); + + let data = createPayloadWithSelector(domainSeparator, hre.ethers.randomBytes(32), hre.ethers.randomBytes(64)); + // replace the message hash + data = data.substring(0, 10) + hre.ethers.hexlify(hre.ethers.randomBytes(32)).slice(2) + data.substring(74); + await expect(account.executeCallViaMock.staticCall(account.target, 0, data, MaxUint256)).to.be.revertedWithCustomError( + safeProtocolSignatureValidatorManager, + "InvalidMessageHash", + ); + }); + + describe("Validation with Hooks", async () => { + const domainSeparator = hre.ethers.randomBytes(32); + const signatureValidatorReturnValue = new hre.ethers.AbiCoder().encode(["bytes4"], ["0x12345678"]); + + const setupHooksTests = deployments.createFixture(async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolManager, safeProtocolRegistry } = await setupTests(); + + // 1. Set fallback handler + await account.setFallbackHandler(safeProtocolManager.target); + + // 2. Set function handler + const setFunctionHandlerData = safeProtocolManager.interface.encodeFunctionData("setFunctionHandler", [ + "0x1626ba7e", + safeProtocolSignatureValidatorManager.target, + ]); + + await account.executeCallViaMock(account.target, 0, setFunctionHandlerData, MaxUint256); + + // set up mock contract as a signature validator + const mockContract = await hre.ethers.deployContract("MockContract", { signer: deployer }); + // 0x38c8d4e6 => isValidSignature(address,address,bytes32,bytes32,bytes32,bytes) + await mockContract.givenMethodReturn("0x38c8d4e6", signatureValidatorReturnValue); + await mockContract.givenMethodReturnBool("0x01ffc9a7", true); + + await safeProtocolRegistry.connect(owner).addModule(mockContract.target, MODULE_TYPE_SIGNATURE_VALIDATOR); + + // 3. Set validator for domain separator + const dataSetValidatorManager = safeProtocolSignatureValidatorManager.interface.encodeFunctionData( + "setSignatureValidator", + [domainSeparator, mockContract.target], + ); + + await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidatorManager, MaxUint256); + + return { account, safeProtocolSignatureValidatorManager, safeProtocolManager, safeProtocolRegistry }; + }); + + it("Should revert if pre-validation fails", async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolRegistry } = await setupHooksTests(); + + // Set validation hooks + const mockSignatureValidatorHooks = await getMockSignatureValidationHooksWithFailingPreValidationHook(); + const dataSetValidationHooks = safeProtocolSignatureValidatorManager.interface.encodeFunctionData( + "setSignatureValidatorHooks", + [mockSignatureValidatorHooks.target], + ); + await safeProtocolRegistry + .connect(owner) + .addModule(mockSignatureValidatorHooks.target, MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS); + await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidationHooks, MaxUint256); + + const data = createPayloadWithSelector(domainSeparator, hre.ethers.randomBytes(32), hre.ethers.randomBytes(64)); + + await expect(account.executeCallViaMock.staticCall(account.target, 0, data, MaxUint256)).to.be.revertedWith( + "Pre-validation failed", + ); + }); + + it("Should revert if post-validation fails", async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolRegistry } = await setupHooksTests(); + + // Set validation hooks + const mockSignatureValidatorHooks = await getMockSignatureValidationHooksWithFailingPostValidationHook(); + const dataSetValidationHooks = safeProtocolSignatureValidatorManager.interface.encodeFunctionData( + "setSignatureValidatorHooks", + [mockSignatureValidatorHooks.target], + ); + await safeProtocolRegistry + .connect(owner) + .addModule(mockSignatureValidatorHooks.target, MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS); + await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidationHooks, MaxUint256); + + const data = createPayloadWithSelector(domainSeparator, hre.ethers.randomBytes(32), hre.ethers.randomBytes(64)); + + await expect(account.executeCallViaMock.staticCall(account.target, 0, data, MaxUint256)).to.be.revertedWith( + "Post-validation failed", + ); + }); + + it("Should call signature validator with validation hooks", async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolRegistry } = await setupHooksTests(); + + // Set validation hooks + const mockSignatureValidatorHooks = await getMockSignatureValidationHooks(); + const dataSetValidationHooks = safeProtocolSignatureValidatorManager.interface.encodeFunctionData( + "setSignatureValidatorHooks", + [mockSignatureValidatorHooks.target], + ); + + await safeProtocolRegistry + .connect(owner) + .addModule(mockSignatureValidatorHooks.target, MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS); + await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidationHooks, MaxUint256); + + const data = createPayloadWithSelector(domainSeparator, hre.ethers.randomBytes(32), hre.ethers.randomBytes(64)); + + expect(await account.executeCallViaMock.staticCall(account.target, 0, data, MaxUint256)).to.be.deep.equal([ + true, + signatureValidatorReturnValue, + ]); + }); + }); + }); + + describe("default signature validation flow", async () => { + it("Should call default signature validator without validation hooks", async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolManager } = await setupTests(); + + // 1. Set fallback handler + await account.setFallbackHandler(safeProtocolManager.target); + + // 2. Set function handler + const setFunctionHandlerData = safeProtocolManager.interface.encodeFunctionData("setFunctionHandler", [ + "0x1626ba7e", + safeProtocolSignatureValidatorManager.target, + ]); + + await account.executeCallViaMock(account.target, 0, setFunctionHandlerData, MaxUint256); + + const data = isValidSignatureInterface.encodeFunctionData("isValidSignature", [ + hre.ethers.randomBytes(32), + hre.ethers.randomBytes(65), + ]); + + expect(await account.executeCallViaMock.staticCall(account.target, 0, data, MaxUint256)).to.be.deep.equal([ + true, + "0x000000000000000000000000000000000000000000000000000000001626ba7e", + ]); + }); + + describe("Validation with Hooks", async () => { + const setupHooksTests = deployments.createFixture(async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolManager, safeProtocolRegistry } = await setupTests(); + + // 1. Set fallback handler + await account.setFallbackHandler(safeProtocolManager.target); + + // 2. Set function handler + const setFunctionHandlerData = safeProtocolManager.interface.encodeFunctionData("setFunctionHandler", [ + "0x1626ba7e", + safeProtocolSignatureValidatorManager.target, + ]); + await account.executeCallViaMock(account.target, 0, setFunctionHandlerData, MaxUint256); + + // 3. Set validation hooks + const mockSignatureValidatorHooks = await getMockSignatureValidationHooks(); + await safeProtocolRegistry + .connect(owner) + .addModule(mockSignatureValidatorHooks.target, MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS); + + const dataSetValidationHooks = safeProtocolSignatureValidatorManager.interface.encodeFunctionData( + "setSignatureValidatorHooks", + [mockSignatureValidatorHooks.target], + ); + await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidationHooks, MaxUint256); + + return { account, safeProtocolSignatureValidatorManager, safeProtocolManager, safeProtocolRegistry }; + }); + + it("Should call default signature validator with validation hooks", async () => { + const { account } = await setupHooksTests(); + + const data = isValidSignatureInterface.encodeFunctionData("isValidSignature", [ + hre.ethers.randomBytes(32), + hre.ethers.randomBytes(65), + ]); + + expect(await account.executeCallViaMock.staticCall(account.target, 0, data, MaxUint256)).to.be.deep.equal([ + true, + "0x000000000000000000000000000000000000000000000000000000001626ba7e", + ]); + }); + + it("Should revert if pre-validation fails", async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolRegistry } = await setupHooksTests(); + + // Set validation hooks + const mockSignatureValidatorHooks = await getMockSignatureValidationHooksWithFailingPreValidationHook(); + const dataSetValidationHooks = safeProtocolSignatureValidatorManager.interface.encodeFunctionData( + "setSignatureValidatorHooks", + [mockSignatureValidatorHooks.target], + ); + + await safeProtocolRegistry + .connect(owner) + .addModule(mockSignatureValidatorHooks.target, MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS); + await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidationHooks, MaxUint256); + + const data = isValidSignatureInterface.encodeFunctionData("isValidSignature", [ + hre.ethers.randomBytes(32), + hre.ethers.randomBytes(65), + ]); + + await expect(account.executeCallViaMock.staticCall(account.target, 0, data, MaxUint256)).to.be.revertedWith( + "Pre-validation failed", + ); + }); + + it("Should revert if post-validation fails", async () => { + const { account, safeProtocolSignatureValidatorManager, safeProtocolRegistry } = await setupHooksTests(); + + // Set validation hooks + const mockSignatureValidatorHooks = await getMockSignatureValidationHooksWithFailingPostValidationHook(); + const dataSetValidationHooks = safeProtocolSignatureValidatorManager.interface.encodeFunctionData( + "setSignatureValidatorHooks", + [mockSignatureValidatorHooks.target], + ); + await safeProtocolRegistry + .connect(owner) + .addModule(mockSignatureValidatorHooks.target, MODULE_TYPE_SIGNATURE_VALIDATOR_HOOKS); + await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidationHooks, MaxUint256); + + const data = isValidSignatureInterface.encodeFunctionData("isValidSignature", [ + hre.ethers.randomBytes(32), + hre.ethers.randomBytes(65), + ]); + + await expect(account.executeCallViaMock.staticCall(account.target, 0, data, MaxUint256)).to.be.revertedWith( + "Post-validation failed", + ); + }); + }); + }); + + it("call metadataProvider() for increasing coverage", async () => { + const { safeProtocolSignatureValidatorManager } = await setupTests(); + expect(await safeProtocolSignatureValidatorManager.metadataProvider()); + }); + + it("Should return true when valid interfaceId is passed", async () => { + const { safeProtocolSignatureValidatorManager } = await setupTests(); + expect(await safeProtocolSignatureValidatorManager.supportsInterface.staticCall("0x01ffc9a7")).to.be.true; + expect(await safeProtocolSignatureValidatorManager.supportsInterface.staticCall("0x86080c7a")).to.be.true; + expect(await safeProtocolSignatureValidatorManager.supportsInterface.staticCall("0xf601ad15")).to.be.true; + }); + + it("Should return false when invalid interfaceId is passed", async () => { + const { safeProtocolSignatureValidatorManager } = await setupTests(); + expect(await safeProtocolSignatureValidatorManager.supportsInterface.staticCall("0x00000000")).to.be.false; + expect(await safeProtocolSignatureValidatorManager.supportsInterface.staticCall("0xbaddad42")).to.be.false; + expect(await safeProtocolSignatureValidatorManager.supportsInterface.staticCall("0xffffffff")).to.be.false; + }); +}); diff --git a/test/utils/contracts.ts b/test/utils/contracts.ts index ff3a565a..4e156a79 100644 --- a/test/utils/contracts.ts +++ b/test/utils/contracts.ts @@ -1,6 +1,18 @@ import { Addressable, BaseContract } from "ethers"; import hre from "hardhat"; +import { HardhatRuntimeEnvironment } from "hardhat/types"; +import { SafeProtocolManager, SafeProtocolRegistry, SignatureValidatorManager } from "../../typechain-types"; + export const getInstance = async (name: string, address: string | Addressable): Promise => { // TODO: this typecasting should be refactored return (await hre.ethers.getContractAt(name, address)) as unknown as T; }; + +export const getSingleton = async (hre: HardhatRuntimeEnvironment, name: string): Promise => { + const deployment = await hre.deployments.get(name); + return getInstance(name, deployment.address); +}; + +export const getSignatureValidatorManager = async () => getSingleton(hre, "SignatureValidatorManager"); +export const getSafeProtocolManager = async () => getSingleton(hre, "SafeProtocolManager"); +export const getRegistry = async () => getSingleton(hre, "SafeProtocolRegistry"); diff --git a/test/utils/mockValidationHooksBuilder.ts b/test/utils/mockValidationHooksBuilder.ts new file mode 100644 index 00000000..4a0605a2 --- /dev/null +++ b/test/utils/mockValidationHooksBuilder.ts @@ -0,0 +1,38 @@ +import hre from "hardhat"; +import { ISafeProtocolSignatureValidatorHooks } from "../../typechain-types"; + +export const getMockSignatureValidationHooks = async (): Promise => { + const signatureValidationHooks = await (await hre.ethers.getContractFactory("MockContract")).deploy(); + + // Supports IERC165 + await signatureValidationHooks.givenMethodReturnBool("0x01ffc9a7", true); + + // 0x3964efae => preValidationHook(address,address,bytes) + // 0xea243a01 => postValidationHook(address,bytes) + await signatureValidationHooks.givenMethodReturn("0x3964efae", hre.ethers.AbiCoder.defaultAbiCoder().encode(["bytes"], ["0x1234"])); + await signatureValidationHooks.givenMethodReturn("0xea243a01", hre.ethers.AbiCoder.defaultAbiCoder().encode(["bytes"], ["0x1234"])); + + return hre.ethers.getContractAt("ISafeProtocolSignatureValidatorHooks", signatureValidationHooks.target); +}; + +export const getMockSignatureValidationHooksWithFailingPreValidationHook = async (): Promise => { + const signatureValidationHooks = await (await hre.ethers.getContractFactory("MockContract")).deploy(); + + // Supports IERC165 + await signatureValidationHooks.givenMethodReturnBool("0x01ffc9a7", true); + + // 0x3964efae => preValidationHook(address,address,bytes) + await signatureValidationHooks.givenMethodRevertWithMessage("0x3964efae", "Pre-validation failed"); + return hre.ethers.getContractAt("ISafeProtocolSignatureValidatorHooks", signatureValidationHooks.target); +}; + +export const getMockSignatureValidationHooksWithFailingPostValidationHook = async (): Promise => { + const signatureValidationHooks = await (await hre.ethers.getContractFactory("MockContract")).deploy(); + + // Supports IERC165 + await signatureValidationHooks.givenMethodReturnBool("0x01ffc9a7", true); + + // 0xea243a01 => postValidationHook(address,bytes) + await signatureValidationHooks.givenMethodRevertWithMessage("0xea243a01", "Post-validation failed"); + return hre.ethers.getContractAt("ISafeProtocolSignatureValidatorHooks", signatureValidationHooks.target); +};