diff --git a/contracts/Diamond.sol b/contracts/Diamond.sol index b0715ba..3e07d31 100644 --- a/contracts/Diamond.sol +++ b/contracts/Diamond.sol @@ -11,55 +11,25 @@ pragma experimental ABIEncoderV2; import "./libraries/LibDiamondStorage.sol"; import "./libraries/LibDiamondCut.sol"; import "./facets/OwnershipFacet.sol"; -import "./facets/DiamondLoupeFacet.sol"; import "./facets/DiamondCutFacet.sol"; -import "./interfaces/IDiamondCut.sol"; +import "./facets/DiamondLoupeFacet.sol"; contract Diamond { event OwnershipTransferred(address indexed previousOwner, address indexed newOwner); - constructor(address owner) payable { - LibDiamondStorage.DiamondStorage storage ds = LibDiamondStorage.diamondStorage(); - ds.contractOwner = owner; - emit OwnershipTransferred(address(0), owner); - - DiamondCutFacet diamondCutFacet = new DiamondCutFacet(); - - DiamondLoupeFacet diamondLoupeFacet = new DiamondLoupeFacet(); - - OwnershipFacet ownershipFacet = new OwnershipFacet(); - - IDiamondCut.Facet[] memory diamondCut = new IDiamondCut.Facet[](3); - - // adding diamondCut function - diamondCut[0].facetAddress = address(diamondCutFacet); - diamondCut[0].functionSelectors = new bytes4[](1); - diamondCut[0].functionSelectors[0] = DiamondCutFacet.diamondCut.selector; - - // adding diamond loupe functions - diamondCut[1].facetAddress = address(diamondLoupeFacet); - diamondCut[1].functionSelectors = new bytes4[](5); - diamondCut[1].functionSelectors[0] = DiamondLoupeFacet.facetFunctionSelectors.selector; - diamondCut[1].functionSelectors[1] = DiamondLoupeFacet.facets.selector; - diamondCut[1].functionSelectors[2] = DiamondLoupeFacet.facetAddress.selector; - diamondCut[1].functionSelectors[3] = DiamondLoupeFacet.facetAddresses.selector; - diamondCut[1].functionSelectors[4] = DiamondLoupeFacet.supportsInterface.selector; - - // adding ownership functions - diamondCut[2].facetAddress = address(ownershipFacet); - diamondCut[2].functionSelectors = new bytes4[](2); - diamondCut[2].functionSelectors[0] = OwnershipFacet.transferOwnership.selector; - diamondCut[2].functionSelectors[1] = OwnershipFacet.owner.selector; - - // execute non-standard internal diamondCut function to add functions - LibDiamondCut.diamondCut(diamondCut); + constructor(address _owner, IDiamondCut.FacetCut[] memory diamondCut) payable { + LibDiamondCut.diamondCut(diamondCut, address(0), new bytes(0)); + LibDiamondStorage.DiamondStorage storage ds = LibDiamondStorage.diamondStorage(); + ds.contractOwner = _owner; + emit OwnershipTransferred(address(0), _owner); + // adding ERC165 data // ERC165 ds.supportedInterfaces[IERC165.supportsInterface.selector] = true; // DiamondCut - ds.supportedInterfaces[IDiamondCut.diamondCut.selector] = true; + ds.supportedInterfaces[DiamondCutFacet.diamondCut.selector] = true; // DiamondLoupe bytes4 interfaceID = IDiamondLoupe.facets.selector ^ diff --git a/contracts/facets/DiamondCutFacet.sol b/contracts/facets/DiamondCutFacet.sol index baf82a8..f26492c 100644 --- a/contracts/facets/DiamondCutFacet.sol +++ b/contracts/facets/DiamondCutFacet.sol @@ -19,37 +19,23 @@ contract DiamondCutFacet is IDiamondCut { /// @param _calldata A function call, including function selector and arguments /// _calldata is executed with delegatecall on _init function diamondCut( - Facet[] calldata _diamondCut, + FacetCut[] calldata _diamondCut, address _init, bytes calldata _calldata ) external override { LibDiamondStorage.DiamondStorage storage ds = LibDiamondStorage.diamondStorage(); require(msg.sender == ds.contractOwner, "DiamondCutFacet: Must own the contract"); + require(_diamondCut.length > 0, "DiamondCutFacet: No facets to cut"); uint256 selectorCount = ds.selectors.length; for (uint256 facetIndex; facetIndex < _diamondCut.length; facetIndex++) { selectorCount = LibDiamondCut.addReplaceRemoveFacetSelectors( selectorCount, _diamondCut[facetIndex].facetAddress, + _diamondCut[facetIndex].action, _diamondCut[facetIndex].functionSelectors ); } emit DiamondCut(_diamondCut, _init, _calldata); - if (_init == address(0)) { - require(_calldata.length == 0, "DiamondCutFacet: _init is address(0) but_calldata is not empty"); - } else { - require(_calldata.length > 0, "DiamondCutFacet: _calldata is empty but _init is not address(0)"); - if (_init != address(this)) { - LibDiamondCut.hasContractCode(_init, "DiamondCutFacet: _init address has no code"); - } - (bool success, bytes memory error) = _init.delegatecall(_calldata); - if (!success) { - if (error.length > 0) { - // bubble up the error - revert(string(error)); - } else { - revert("DiamondCutFacet: _init function reverted"); - } - } - } + LibDiamondCut.initializeDiamondCut(_init, _calldata); } } diff --git a/contracts/interfaces/IDiamondCut.sol b/contracts/interfaces/IDiamondCut.sol index 2474804..913d88e 100644 --- a/contracts/interfaces/IDiamondCut.sol +++ b/contracts/interfaces/IDiamondCut.sol @@ -2,9 +2,16 @@ pragma solidity ^0.7.1; pragma experimental ABIEncoderV2; +/******************************************************************************\ +* Author: Nick Mudge (https://twitter.com/mudgen) +/******************************************************************************/ + interface IDiamondCut { - struct Facet { + enum FacetCutAction {Add, Replace, Remove} + + struct FacetCut { address facetAddress; + FacetCutAction action; bytes4[] functionSelectors; } @@ -15,10 +22,10 @@ interface IDiamondCut { /// @param _calldata A function call, including function selector and arguments /// _calldata is executed with delegatecall on _init function diamondCut( - Facet[] calldata _diamondCut, + FacetCut[] calldata _diamondCut, address _init, bytes calldata _calldata ) external; - event DiamondCut(Facet[] _diamondCut, address _init, bytes _calldata); + event DiamondCut(FacetCut[] _diamondCut, address _init, bytes _calldata); } diff --git a/contracts/libraries/LibDiamondCut.sol b/contracts/libraries/LibDiamondCut.sol index 5cb7701..01b199d 100644 --- a/contracts/libraries/LibDiamondCut.sol +++ b/contracts/libraries/LibDiamondCut.sol @@ -12,66 +12,78 @@ import "./LibDiamondStorage.sol"; import "../interfaces/IDiamondCut.sol"; library LibDiamondCut { - event DiamondCut(IDiamondCut.Facet[] _diamondCut, address _init, bytes _calldata); + event DiamondCut(IDiamondCut.FacetCut[] _diamondCut, address _init, bytes _calldata); - // Non-standard internal function version of diamondCut + // Internal function version of diamondCut // This code is almost the same as the external diamondCut, // except it is using 'Facet[] memory _diamondCut' instead of // 'Facet[] calldata _diamondCut'. // The code is duplicated to prevent copying calldata to memory which // causes an error for a two dimensional array. - function diamondCut(IDiamondCut.Facet[] memory _diamondCut) internal { + function diamondCut( + IDiamondCut.FacetCut[] memory _diamondCut, + address _init, + bytes memory _calldata + ) internal { + require(_diamondCut.length > 0, "LibDiamondCut: No facets to cut"); LibDiamondStorage.DiamondStorage storage ds = LibDiamondStorage.diamondStorage(); uint256 selectorCount = ds.selectors.length; for (uint256 facetIndex; facetIndex < _diamondCut.length; facetIndex++) { selectorCount = addReplaceRemoveFacetSelectors( selectorCount, _diamondCut[facetIndex].facetAddress, + _diamondCut[facetIndex].action, _diamondCut[facetIndex].functionSelectors ); } - emit DiamondCut(_diamondCut, address(0), new bytes(0)); + emit DiamondCut(_diamondCut, _init, _calldata); + initializeDiamondCut(_init, _calldata); } function addReplaceRemoveFacetSelectors( uint256 _selectorCount, address _newFacetAddress, + IDiamondCut.FacetCutAction _action, bytes4[] memory _selectors ) internal returns (uint256) { LibDiamondStorage.DiamondStorage storage ds = LibDiamondStorage.diamondStorage(); + require(_selectors.length > 0, "LibDiamondCut: No selectors in facet to cut"); + // add or replace functions if (_newFacetAddress != address(0)) { hasContractCode(_newFacetAddress, "LibDiamondCut: facet has no code"); - // add and replace selectors for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) { bytes4 selector = _selectors[selectorIndex]; address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress; // add - if (oldFacetAddress == address(0)) { + if (_action == IDiamondCut.FacetCutAction.Add) { + require(oldFacetAddress == address(0), "LibDiamondCut: Can't add function that already exists"); ds.facetAddressAndSelectorPosition[selector] = LibDiamondStorage.FacetAddressAndSelectorPosition( _newFacetAddress, uint16(_selectorCount) ); ds.selectors.push(selector); _selectorCount++; - } else { + } else if (_action == IDiamondCut.FacetCutAction.Replace) { // replace + // only useful if immutable functions exist require(oldFacetAddress != address(this), "LibDiamondCut: Can't replace immutable function"); - if (oldFacetAddress != _newFacetAddress) { - // replace old facet address - ds.facetAddressAndSelectorPosition[selector].facetAddress = _newFacetAddress; - } + require(oldFacetAddress != _newFacetAddress, "LibDiamondCut: Can't replace function with same function"); + require(oldFacetAddress != address(0), "LibDiamondCut: Can't replace function that doesn't exist"); + // replace old facet address + ds.facetAddressAndSelectorPosition[selector].facetAddress = _newFacetAddress; + } else { + revert("LibDiamondCut: Incorrect FacetCutAction"); } } } else { + require(_action == IDiamondCut.FacetCutAction.Remove, "LibDiamondCut: action not set to FacetCutAction.Remove"); // remove functions for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) { bytes4 selector = _selectors[selectorIndex]; LibDiamondStorage.FacetAddressAndSelectorPosition memory oldFacetAddressAndSelectorPosition = ds .facetAddressAndSelectorPosition[selector]; - // if selector already does not exist then do nothing - if (oldFacetAddressAndSelectorPosition.facetAddress == address(0)) { - continue; - } + require(oldFacetAddressAndSelectorPosition.facetAddress != address(0), "LibDiamondCut: Can't remove function that doesn't exist"); + // only useful if immutable functions exist require(oldFacetAddressAndSelectorPosition.facetAddress != address(this), "LibDiamondCut: Can't remove immutable function."); bytes4 lastSelector = ds.selectors[_selectorCount - 1]; // replace selector with last selector @@ -88,6 +100,26 @@ library LibDiamondCut { return _selectorCount; } + function initializeDiamondCut(address _init, bytes memory _calldata) internal { + if (_init == address(0)) { + require(_calldata.length == 0, "LibDiamondCut: _init is address(0) but_calldata is not empty"); + } else { + require(_calldata.length > 0, "LibDiamondCut: _calldata is empty but _init is not address(0)"); + if (_init != address(this)) { + LibDiamondCut.hasContractCode(_init, "LibDiamondCut: _init address has no code"); + } + (bool success, bytes memory error) = _init.delegatecall(_calldata); + if (!success) { + if (error.length > 0) { + // bubble up the error + revert(string(error)); + } else { + revert("LibDiamondCut: _init function reverted"); + } + } + } + } + function hasContractCode(address _contract, string memory _errorMessage) internal view { uint256 contractSize; assembly { diff --git a/migrations/1_initial_migration.js b/migrations/1_initial_migration.js index 4146b01..a548a72 100644 --- a/migrations/1_initial_migration.js +++ b/migrations/1_initial_migration.js @@ -1,3 +1,6 @@ +/* eslint-disable prefer-const */ +/* global artifacts */ + const Migrations = artifacts.require('Migrations') module.exports = function (deployer) { diff --git a/migrations/2_diamond.js b/migrations/2_diamond.js index 728656a..898f425 100644 --- a/migrations/2_diamond.js +++ b/migrations/2_diamond.js @@ -1,11 +1,43 @@ +/* eslint-disable prefer-const */ +/* global artifacts */ + const Diamond = artifacts.require('Diamond') +const DiamondCutFacet = artifacts.require('DiamondCutFacet') +const DiamondLoupeFacet = artifacts.require('DiamondLoupeFacet') +const OwnershipFacet = artifacts.require('OwnershipFacet') const Test1Facet = artifacts.require('Test1Facet') const Test2Facet = artifacts.require('Test2Facet') +const FacetCutAction = { + Add: 0, + Replace: 1, + Remove: 2 +} + +function getSelectors (contract) { + const selectors = contract.abi.reduce((acc, val) => { + if (val.type === 'function') { + acc.push(val.signature) + return acc + } else { + return acc + } + }, []) + return selectors +} + module.exports = function (deployer, network, accounts) { - // deployment steps - // The constructor inside Diamond deploys DiamondFacet - deployer.deploy(Diamond, accounts[0]) deployer.deploy(Test1Facet) deployer.deploy(Test2Facet) + + deployer.deploy(DiamondCutFacet) + deployer.deploy(DiamondLoupeFacet) + deployer.deploy(OwnershipFacet).then(() => { + const diamondCut = [ + [DiamondCutFacet.address, FacetCutAction.Add, getSelectors(DiamondCutFacet)], + [DiamondLoupeFacet.address, FacetCutAction.Add, getSelectors(DiamondLoupeFacet)], + [OwnershipFacet.address, FacetCutAction.Add, getSelectors(OwnershipFacet)] + ] + return deployer.deploy(Diamond, accounts[0], diamondCut) + }) } diff --git a/test/cacheBugTest.js b/test/cacheBugTest.js index d343895..4c9250e 100644 --- a/test/cacheBugTest.js +++ b/test/cacheBugTest.js @@ -6,6 +6,12 @@ const DiamondCutFacet = artifacts.require('DiamondCutFacet') const DiamondLoupeFacet = artifacts.require('DiamondLoupeFacet') const Test1Facet = artifacts.require('Test1Facet') +const FacetCutAction = { + Add: 0, + Replace: 1, + Remove: 2 +} + // The diamond example comes with 8 function selectors // [cut, loupe, loupe, loupe, loupe, erc165, transferOwnership, owner] // This bug manifests if you delete something from the final @@ -53,7 +59,7 @@ contract('Cache bug test', async accounts => { web3.eth.defaultAccount = accounts[0] // Add functions - await diamondCutFacet.methods.diamondCut([[test1Facet.address, selectors]], zeroAddress, '0x').send({ from: web3.eth.defaultAccount, gas: 1000000 }) + await diamondCutFacet.methods.diamondCut([[test1Facet.address, FacetCutAction.Add, selectors]], zeroAddress, '0x').send({ from: web3.eth.defaultAccount, gas: 1000000 }) // Remove function selectors // Function selector for the owner function in slot 0 @@ -62,7 +68,7 @@ contract('Cache bug test', async accounts => { sel5, sel10 ] - await diamondCutFacet.methods.diamondCut([[zeroAddress, selectors]], zeroAddress, '0x').send({ from: web3.eth.defaultAccount, gas: 1000000 }) + await diamondCutFacet.methods.diamondCut([[zeroAddress, FacetCutAction.Remove, selectors]], zeroAddress, '0x').send({ from: web3.eth.defaultAccount, gas: 1000000 }) }) it('should not exhibit the cache bug', async () => { diff --git a/test/diamondTest.js b/test/diamondTest.js index 878aaca..bfa6ac7 100644 --- a/test/diamondTest.js +++ b/test/diamondTest.js @@ -7,6 +7,11 @@ const DiamondLoupeFacet = artifacts.require('DiamondLoupeFacet') const OwnershipFacet = artifacts.require('OwnershipFacet') const Test1Facet = artifacts.require('Test1Facet') const Test2Facet = artifacts.require('Test2Facet') +const FacetCutAction = { + Add: 0, + Replace: 1, + Remove: 2 +} function getSelectors (contract) { const selectors = contract.abi.reduce((acc, val) => { @@ -36,6 +41,8 @@ function findPositionInFacets (facetAddress, facets) { contract('DiamondTest', async (accounts) => { let diamondCutFacet let diamondLoupeFacet + // eslint-disable-next-line no-unused-vars + let ownershipFacet let diamond let test1Facet let test2Facet @@ -53,7 +60,7 @@ contract('DiamondTest', async (accounts) => { // unfortunately this is done for the side affect of making selectors available in the ABI of // OwnershipFacet // eslint-disable-next-line no-unused-vars - const ownershipFacet = new web3.eth.Contract(OwnershipFacet.abi, diamond.address) + ownershipFacet = new web3.eth.Contract(OwnershipFacet.abi, diamond.address) web3.eth.defaultAccount = accounts[0] }) @@ -78,7 +85,7 @@ contract('DiamondTest', async (accounts) => { it('selectors should be associated to facets correctly -- multiple calls to facetAddress function', async () => { assert.equal( addresses[0], - await diamondLoupeFacet.methods.facetAddress('0xe712b4e1').call() + await diamondLoupeFacet.methods.facetAddress('0x1f931c1c').call() ) assert.equal( addresses[1], @@ -109,20 +116,29 @@ contract('DiamondTest', async (accounts) => { }) it('should add test1 functions', async () => { - let selectors = getSelectors(test1Facet) + let selectors = getSelectors(test1Facet).slice(0, -1) addresses.push(test1Facet.address) await diamondCutFacet.methods - .diamondCut([[test1Facet.address, selectors]], zeroAddress, '0x') + .diamondCut([[test1Facet.address, FacetCutAction.Add, selectors]], zeroAddress, '0x') .send({ from: web3.eth.defaultAccount, gas: 1000000 }) result = await diamondLoupeFacet.methods.facetFunctionSelectors(addresses[3]).call() assert.sameMembers(result, selectors) }) + it('should replace test1 function', async () => { + let selectors = getSelectors(test1Facet).slice(-1) + await diamondCutFacet.methods + .diamondCut([[test1Facet.address, FacetCutAction.Replace, selectors]], zeroAddress, '0x') + .send({ from: web3.eth.defaultAccount, gas: 1000000 }) + result = await diamondLoupeFacet.methods.facetFunctionSelectors(addresses[3]).call() + assert.sameMembers(result, getSelectors(test1Facet)) + }) + it('should add test2 functions', async () => { const selectors = getSelectors(test2Facet) addresses.push(test2Facet.address) await diamondCutFacet.methods - .diamondCut([[test2Facet.address, selectors]], zeroAddress, '0x') + .diamondCut([[test2Facet.address, FacetCutAction.Add, selectors]], zeroAddress, '0x') .send({ from: web3.eth.defaultAccount, gas: 1000000 }) result = await diamondLoupeFacet.methods.facetFunctionSelectors(addresses[4]).call() assert.sameMembers(result, selectors) @@ -132,7 +148,7 @@ contract('DiamondTest', async (accounts) => { let selectors = getSelectors(test2Facet) let removeSelectors = [].concat(selectors.slice(0, 1), selectors.slice(4, 6), selectors.slice(-2)) result = await diamondCutFacet.methods - .diamondCut([[zeroAddress, removeSelectors]], zeroAddress, '0x') + .diamondCut([[zeroAddress, FacetCutAction.Remove, removeSelectors]], zeroAddress, '0x') .send({ from: web3.eth.defaultAccount, gas: 1000000 }) result = await diamondLoupeFacet.methods.facetFunctionSelectors(addresses[4]).call() selectors = @@ -150,8 +166,8 @@ contract('DiamondTest', async (accounts) => { let removeSelectors = [].concat(selectors.slice(1, 2), selectors.slice(8, 10)) result = await diamondLoupeFacet.methods.facetFunctionSelectors(addresses[3]).call() result = await diamondCutFacet.methods - .diamondCut([[zeroAddress, removeSelectors]], zeroAddress, '0x') - .send({ from: web3.eth.defaultAccount, gas: 7000000 }) + .diamondCut([[zeroAddress, FacetCutAction.Remove, removeSelectors]], zeroAddress, '0x') + .send({ from: web3.eth.defaultAccount, gas: 6000000 }) result = await diamondLoupeFacet.methods.facetFunctionSelectors(addresses[3]).call() selectors = [].concat(selectors.slice(0, 1), selectors.slice(2, 8), selectors.slice(10)) assert.sameMembers(result, selectors) @@ -167,25 +183,28 @@ contract('DiamondTest', async (accounts) => { removeItem(removeSelectors, '0x7a0ed627') result = await diamondCutFacet.methods - .diamondCut([[zeroAddress, removeSelectors]], zeroAddress, '0x') - .send({ from: web3.eth.defaultAccount, gas: 7000000 }) + .diamondCut([[zeroAddress, FacetCutAction.Remove, removeSelectors]], zeroAddress, '0x') + .send({ from: web3.eth.defaultAccount, gas: 6000000 }) facets = await diamondLoupeFacet.methods.facets().call() assert.equal(facets.length, 2) assert.equal(facets[0][0], addresses[0]) - assert.sameMembers(facets[0][1], ['0xe712b4e1']) + assert.sameMembers(facets[0][1], ['0x1f931c1c']) assert.equal(facets[1][0], addresses[1]) assert.sameMembers(facets[1][1], ['0x7a0ed627']) }) it('add most functions and facets', async () => { const diamondCut = [] - diamondCut.push([addresses[1], getSelectors(DiamondLoupeFacet)]) - diamondCut.push([addresses[2], getSelectors(OwnershipFacet)]) - diamondCut.push([addresses[3], getSelectors(test1Facet)]) - diamondCut.push([addresses[4], getSelectors(test2Facet)]) + const selectors = getSelectors(DiamondLoupeFacet) + removeItem(selectors, '0x7a0ed627') + selectors.pop() // remove supportsInterface which will be added later + diamondCut.push([addresses[1], FacetCutAction.Add, selectors]) + diamondCut.push([addresses[2], FacetCutAction.Add, getSelectors(OwnershipFacet)]) + diamondCut.push([addresses[3], FacetCutAction.Add, getSelectors(test1Facet)]) + diamondCut.push([addresses[4], FacetCutAction.Add, getSelectors(test2Facet)]) result = await diamondCutFacet.methods .diamondCut(diamondCut, zeroAddress, '0x') - .send({ from: web3.eth.defaultAccount, gas: 7000000 }) + .send({ from: web3.eth.defaultAccount, gas: 6000000 }) const facets = await diamondLoupeFacet.methods.facets().call() const facetAddresses = await diamondLoupeFacet.methods.facetAddresses().call() assert.equal(facetAddresses.length, 5) diff --git a/truffle-config.js b/truffle-config.js index f65dd3f..1e16676 100644 --- a/truffle-config.js +++ b/truffle-config.js @@ -79,7 +79,7 @@ module.exports = { // Set default mocha options here, use special reporters etc. mocha: { - // reporter: 'eth-gas-reporter' + reporter: 'eth-gas-reporter' // timeout: 100000 },