diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index adfc1e7a..cc38a45c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -25,7 +25,7 @@ If you are submitting a bug report, please: Please send a [GitHub Pull Request to safe-core-protocol repository](https://github.com/safe-global/safe-core-protocol) with a clear description of the proposed changes. Each pull request should be associated with an issue and should be made against the `main` branch. Branch naming convention: - + - For a new feature, use `feature--short-description` - For a bug fix, use `fix--short-description` @@ -40,4 +40,4 @@ Steps to be taken before submitting a pull request to be considered for review: - Make sure there are no linting errors Thanks, -Safe team \ No newline at end of file +Safe team diff --git a/contracts/SafeProtocolManager.sol b/contracts/SafeProtocolManager.sol index 7012b720..93f6f0a1 100644 --- a/contracts/SafeProtocolManager.sol +++ b/contracts/SafeProtocolManager.sol @@ -24,7 +24,8 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana /** * @notice Mapping of a mapping what stores information about plugins that are enabled per account. - * address (Account address) => address (module address) => EnabledPluginInfo + * address (module address) => address (account address) => EnabledPluginInfo + * @dev The key of the inner-most mapping is the account address, which is required for 4337-compatibility. */ mapping(address => mapping(address => PluginAccessInfo)) public enabledPlugins; struct PluginAccessInfo { @@ -180,8 +181,8 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana if (!ISafeProtocolPlugin(plugin).supportsInterface(type(ISafeProtocolPlugin).interfaceId)) revert ContractDoesNotImplementValidInterfaceId(plugin); - PluginAccessInfo storage senderSentinelPlugin = enabledPlugins[msg.sender][SENTINEL_MODULES]; - PluginAccessInfo storage senderPlugin = enabledPlugins[msg.sender][plugin]; + PluginAccessInfo storage senderSentinelPlugin = enabledPlugins[SENTINEL_MODULES][msg.sender]; + PluginAccessInfo storage senderPlugin = enabledPlugins[plugin][msg.sender]; if (senderPlugin.nextPluginPointer != address(0)) { revert PluginAlreadyEnabled(msg.sender, plugin); @@ -208,8 +209,8 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana * @param plugin Plugin to be disabled */ function disablePlugin(address prevPlugin, address plugin) external noZeroOrSentinelPlugin(plugin) onlyAccount { - PluginAccessInfo storage prevPluginInfo = enabledPlugins[msg.sender][prevPlugin]; - PluginAccessInfo storage pluginInfo = enabledPlugins[msg.sender][plugin]; + PluginAccessInfo storage prevPluginInfo = enabledPlugins[prevPlugin][msg.sender]; + PluginAccessInfo storage pluginInfo = enabledPlugins[plugin][msg.sender]; if (prevPluginInfo.nextPluginPointer != plugin) { revert InvalidPrevPluginAddress(prevPlugin); @@ -229,7 +230,7 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana * @param plugin Address of a plugin */ function getPluginInfo(address account, address plugin) external view returns (PluginAccessInfo memory enabled) { - return enabledPlugins[account][plugin]; + return enabledPlugins[plugin][account]; } /** @@ -239,7 +240,7 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana * @return True if the plugin is enabled */ function isPluginEnabled(address account, address plugin) public view returns (bool) { - return SENTINEL_MODULES != plugin && enabledPlugins[account][plugin].nextPluginPointer != address(0); + return SENTINEL_MODULES != plugin && enabledPlugins[plugin][account].nextPluginPointer != address(0); } /** @@ -268,10 +269,10 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana // Populate return array uint256 pluginCount = 0; - next = enabledPlugins[account][start].nextPluginPointer; + next = enabledPlugins[start][account].nextPluginPointer; while (next != address(0) && next != SENTINEL_MODULES && pluginCount < pageSize) { array[pluginCount] = next; - next = enabledPlugins[account][next].nextPluginPointer; + next = enabledPlugins[next][account].nextPluginPointer; pluginCount++; } @@ -282,10 +283,10 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana /** Because of the argument validation, we can assume that the loop will always iterate over the valid plugin list values - and the `next` variable will either be an enabled plugin or a sentinel address (signalling the end). - + and the `next` variable will either be an enabled plugin or a sentinel address (signalling the end). + If we haven't reached the end inside the loop, we need to set the next pointer to the last element of the plugins array - because the `next` variable (which is a plugin by itself) acting as a pointer to the start of the next page is neither + because the `next` variable (which is a plugin by itself) acting as a pointer to the start of the next page is neither included to the current page, nor will it be included in the next one if you pass it as a start. */ if (next != SENTINEL_MODULES && pluginCount != 0) { @@ -436,7 +437,7 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana } function checkOnlyEnabledPlugin(address account) private view { - if (enabledPlugins[account][msg.sender].nextPluginPointer == address(0)) { + if (enabledPlugins[msg.sender][account].nextPluginPointer == address(0)) { revert PluginNotEnabled(msg.sender); } } @@ -457,7 +458,7 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana */ function checkPermission(address account, uint8 permission) private view { // For each action, Manager will read storage and call plugin's requiresPermissions(). - uint8 givenPermissions = enabledPlugins[account][msg.sender].permissions; + uint8 givenPermissions = enabledPlugins[msg.sender][account].permissions; uint8 requiresPermissions = ISafeProtocolPlugin(msg.sender).requiresPermissions(); if ((requiresPermissions & givenPermissions & permission) != permission) { diff --git a/contracts/SignatureValidatorManager.sol b/contracts/SignatureValidatorManager.sol index 9f6cdcd7..114bb44b 100644 --- a/contracts/SignatureValidatorManager.sol +++ b/contracts/SignatureValidatorManager.sol @@ -37,9 +37,10 @@ contract SignatureValidatorManager is RegistryManager, ISafeProtocolFunctionHand // Storage /** - * @notice Mapping to account address => domain separator => signature validator contract + * @notice Mapping to domain separator => account address => signature validator contract + * @dev The key of the inner-most mapping is the account address, which is required for 4337-compatibility. */ - mapping(address => mapping(bytes32 => address)) public signatureValidators; + mapping(bytes32 => mapping(address => address)) public signatureValidators; /** * @notice Mapping to account address => signature validator hooks contract @@ -72,7 +73,7 @@ contract SignatureValidatorManager is RegistryManager, ISafeProtocolFunctionHand if (!ISafeProtocolSignatureValidator(signatureValidator).supportsInterface(type(ISafeProtocolSignatureValidator).interfaceId)) revert ContractDoesNotImplementValidInterfaceId(signatureValidator); } - signatureValidators[msg.sender][domainSeparator] = signatureValidator; + signatureValidators[domainSeparator][msg.sender] = signatureValidator; emit SignatureValidatorChanged(msg.sender, domainSeparator, signatureValidator); } @@ -193,7 +194,7 @@ contract SignatureValidatorManager is RegistryManager, ISafeProtocolFunctionHand revert InvalidMessageHash(messageHash); } - address signatureValidator = signatureValidators[account][domainSeparator]; + address signatureValidator = signatureValidators[domainSeparator][account]; if (signatureValidator == address(0)) { revert SignatureValidatorNotSet(account); } diff --git a/contracts/base/FunctionHandlerManager.sol b/contracts/base/FunctionHandlerManager.sol index 984565a6..e3489ef9 100644 --- a/contracts/base/FunctionHandlerManager.sol +++ b/contracts/base/FunctionHandlerManager.sol @@ -14,9 +14,11 @@ import {MODULE_TYPE_FUNCTION_HANDLER} from "../common/Constants.sol"; */ abstract contract FunctionHandlerManager is RegistryManager { // Storage - /** @dev Mapping that stores information about an account, function selector, and address of the account. + /** + * @notice Mapping that stores information about an account, function selector, and address of the account. + * @dev The key of the inner-most mapping is the account address, which is required for 4337-compatibility. */ - mapping(address => mapping(bytes4 => address)) public functionHandlers; + mapping(bytes4 => mapping(address => address)) public functionHandlers; // Events event FunctionHandlerChanged(address indexed account, bytes4 indexed selector, address indexed functionHandler); @@ -31,7 +33,7 @@ abstract contract FunctionHandlerManager is RegistryManager { * @return functionHandler Address of the contract to be set as a function handler */ function getFunctionHandler(address account, bytes4 selector) external view returns (address functionHandler) { - functionHandler = functionHandlers[account][selector]; + functionHandler = functionHandlers[selector][account]; } /** @@ -48,7 +50,7 @@ abstract contract FunctionHandlerManager is RegistryManager { } // No need to check if functionHandler implements expected interfaceId as check will be done when adding to registry. - functionHandlers[msg.sender][selector] = functionHandler; + functionHandlers[selector][msg.sender] = functionHandler; emit FunctionHandlerChanged(msg.sender, selector, functionHandler); } @@ -63,7 +65,7 @@ abstract contract FunctionHandlerManager is RegistryManager { address account = msg.sender; bytes4 functionSelector = bytes4(msg.data); - address functionHandler = functionHandlers[account][functionSelector]; + address functionHandler = functionHandlers[functionSelector][account]; // Revert if functionHandler is not set if (functionHandler == address(0)) { diff --git a/docs/execution_flows.md b/docs/execution_flows.md index 4869f67a..84aa0c43 100644 --- a/docs/execution_flows.md +++ b/docs/execution_flows.md @@ -43,7 +43,7 @@ end subgraph SafeProtocolManager ExamplePlugin1 -->|Execute tx for an Account through Plugin| Execute_Transaction(Execute transaction from a Plugin) --> Validate_ExecuteFromPluginFlow{Is Plugin Enabled?
Call SafeProtocolRegistry
and validate if Plugin trusted} - Validate_ExecuteFromPluginFlow -- No ----> E(Revert transaction) + Validate_ExecuteFromPluginFlow -- No ----> E(Revert transaction) end ``` @@ -118,4 +118,4 @@ SafeProtocolManager --> isValidSignature{isValidSignature} isValidSignature --> |Yes| ExecuteTx(Continue transaction execution) User("`Users(s)`") --> |Generate an Account signature| Sign_Transaction -``` \ No newline at end of file +``` diff --git a/src/tasks/generate_deployments_markdown.ts b/src/tasks/generate_deployments_markdown.ts index e358c7e9..68f3d4b3 100644 --- a/src/tasks/generate_deployments_markdown.ts +++ b/src/tasks/generate_deployments_markdown.ts @@ -11,7 +11,7 @@ task("generate:deployments", "Generate markdown file with deployed contract addr console.error("No deployments file found. Please run the deployment script first."); return; } - + const {default: deployments} = await import("../../deployments"); const markdownFile = "./docs/deployments.md"; diff --git a/src/utils/constants.ts b/src/utils/constants.ts index 0c4cfc52..4c9b2d0e 100644 --- a/src/utils/constants.ts +++ b/src/utils/constants.ts @@ -12,4 +12,4 @@ export const MODULE_TYPE_SIGNATURE_VALIDATOR: number = 16; // solidity: bytes4(keccak256("Account712Signature(bytes32,bytes32,bytes)")); // javascript: hre.ethers.keccak256(toUtf8Bytes("Account712Signature(bytes32,bytes32,bytes)")).slice(0, 10); -export const SIGNATURE_VALIDATOR_SELECTOR = "0xb5c726cb"; +export const SIGNATURE_VALIDATOR_SELECTOR = "0xb5c726cb"; diff --git a/test/SignatureValidatorManager.spec.ts b/test/SignatureValidatorManager.spec.ts index d6ccc62b..d38896c4 100644 --- a/test/SignatureValidatorManager.spec.ts +++ b/test/SignatureValidatorManager.spec.ts @@ -79,7 +79,7 @@ describe("SignatureValidatorManager", () => { await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidator, MaxUint256); - expect(await safeProtocolSignatureValidatorManager.signatureValidators(account.target, domainSeparator)).to.be.equal( + expect(await safeProtocolSignatureValidatorManager.signatureValidators(domainSeparator, account.target)).to.be.equal( mockContract.target, ); @@ -89,7 +89,7 @@ describe("SignatureValidatorManager", () => { ]); await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataResetValidator, MaxUint256); - expect(await safeProtocolSignatureValidatorManager.signatureValidators(account.target, domainSeparator)).to.be.equal(ZeroAddress); + expect(await safeProtocolSignatureValidatorManager.signatureValidators(domainSeparator, account.target)).to.be.equal(ZeroAddress); }); it("should revert when enabling a signature validator hooks not implementing ISafeProtocolSignatureValidatorHooks interface", async () => {