Skip to content

Commit

Permalink
dev
Browse files Browse the repository at this point in the history
  • Loading branch information
mudgen committed Sep 25, 2020
1 parent c5e2794 commit a98a98d
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 96 deletions.
46 changes: 8 additions & 38 deletions contracts/Diamond.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,55 +11,25 @@ pragma experimental ABIEncoderV2;
import "./libraries/LibDiamondStorage.sol";
import "./libraries/LibDiamondCut.sol";
import "./facets/OwnershipFacet.sol";
import "./facets/DiamondLoupeFacet.sol";
import "./facets/DiamondCutFacet.sol";
import "./interfaces/IDiamondCut.sol";
import "./facets/DiamondLoupeFacet.sol";

contract Diamond {
event OwnershipTransferred(address indexed previousOwner, address indexed newOwner);

constructor(address owner) payable {
LibDiamondStorage.DiamondStorage storage ds = LibDiamondStorage.diamondStorage();
ds.contractOwner = owner;
emit OwnershipTransferred(address(0), owner);

DiamondCutFacet diamondCutFacet = new DiamondCutFacet();

DiamondLoupeFacet diamondLoupeFacet = new DiamondLoupeFacet();

OwnershipFacet ownershipFacet = new OwnershipFacet();

IDiamondCut.Facet[] memory diamondCut = new IDiamondCut.Facet[](3);

// adding diamondCut function
diamondCut[0].facetAddress = address(diamondCutFacet);
diamondCut[0].functionSelectors = new bytes4[](1);
diamondCut[0].functionSelectors[0] = DiamondCutFacet.diamondCut.selector;

// adding diamond loupe functions
diamondCut[1].facetAddress = address(diamondLoupeFacet);
diamondCut[1].functionSelectors = new bytes4[](5);
diamondCut[1].functionSelectors[0] = DiamondLoupeFacet.facetFunctionSelectors.selector;
diamondCut[1].functionSelectors[1] = DiamondLoupeFacet.facets.selector;
diamondCut[1].functionSelectors[2] = DiamondLoupeFacet.facetAddress.selector;
diamondCut[1].functionSelectors[3] = DiamondLoupeFacet.facetAddresses.selector;
diamondCut[1].functionSelectors[4] = DiamondLoupeFacet.supportsInterface.selector;

// adding ownership functions
diamondCut[2].facetAddress = address(ownershipFacet);
diamondCut[2].functionSelectors = new bytes4[](2);
diamondCut[2].functionSelectors[0] = OwnershipFacet.transferOwnership.selector;
diamondCut[2].functionSelectors[1] = OwnershipFacet.owner.selector;

// execute non-standard internal diamondCut function to add functions
LibDiamondCut.diamondCut(diamondCut);
constructor(address _owner, IDiamondCut.FacetCut[] memory diamondCut) payable {
LibDiamondCut.diamondCut(diamondCut, address(0), new bytes(0));

LibDiamondStorage.DiamondStorage storage ds = LibDiamondStorage.diamondStorage();
ds.contractOwner = _owner;
emit OwnershipTransferred(address(0), _owner);

// adding ERC165 data
// ERC165
ds.supportedInterfaces[IERC165.supportsInterface.selector] = true;

// DiamondCut
ds.supportedInterfaces[IDiamondCut.diamondCut.selector] = true;
ds.supportedInterfaces[DiamondCutFacet.diamondCut.selector] = true;

// DiamondLoupe
bytes4 interfaceID = IDiamondLoupe.facets.selector ^
Expand Down
22 changes: 4 additions & 18 deletions contracts/facets/DiamondCutFacet.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,23 @@ contract DiamondCutFacet is IDiamondCut {
/// @param _calldata A function call, including function selector and arguments
/// _calldata is executed with delegatecall on _init
function diamondCut(
Facet[] calldata _diamondCut,
FacetCut[] calldata _diamondCut,
address _init,
bytes calldata _calldata
) external override {
LibDiamondStorage.DiamondStorage storage ds = LibDiamondStorage.diamondStorage();
require(msg.sender == ds.contractOwner, "DiamondCutFacet: Must own the contract");
require(_diamondCut.length > 0, "DiamondCutFacet: No facets to cut");
uint256 selectorCount = ds.selectors.length;
for (uint256 facetIndex; facetIndex < _diamondCut.length; facetIndex++) {
selectorCount = LibDiamondCut.addReplaceRemoveFacetSelectors(
selectorCount,
_diamondCut[facetIndex].facetAddress,
_diamondCut[facetIndex].action,
_diamondCut[facetIndex].functionSelectors
);
}
emit DiamondCut(_diamondCut, _init, _calldata);
if (_init == address(0)) {
require(_calldata.length == 0, "DiamondCutFacet: _init is address(0) but_calldata is not empty");
} else {
require(_calldata.length > 0, "DiamondCutFacet: _calldata is empty but _init is not address(0)");
if (_init != address(this)) {
LibDiamondCut.hasContractCode(_init, "DiamondCutFacet: _init address has no code");
}
(bool success, bytes memory error) = _init.delegatecall(_calldata);
if (!success) {
if (error.length > 0) {
// bubble up the error
revert(string(error));
} else {
revert("DiamondCutFacet: _init function reverted");
}
}
}
LibDiamondCut.initializeDiamondCut(_init, _calldata);
}
}
13 changes: 10 additions & 3 deletions contracts/interfaces/IDiamondCut.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@
pragma solidity ^0.7.1;
pragma experimental ABIEncoderV2;

/******************************************************************************\
* Author: Nick Mudge <[email protected]> (https://twitter.com/mudgen)
/******************************************************************************/

interface IDiamondCut {
struct Facet {
enum FacetCutAction {Add, Replace, Remove}

struct FacetCut {
address facetAddress;
FacetCutAction action;
bytes4[] functionSelectors;
}

Expand All @@ -15,10 +22,10 @@ interface IDiamondCut {
/// @param _calldata A function call, including function selector and arguments
/// _calldata is executed with delegatecall on _init
function diamondCut(
Facet[] calldata _diamondCut,
FacetCut[] calldata _diamondCut,
address _init,
bytes calldata _calldata
) external;

event DiamondCut(Facet[] _diamondCut, address _init, bytes _calldata);
event DiamondCut(FacetCut[] _diamondCut, address _init, bytes _calldata);
}
62 changes: 47 additions & 15 deletions contracts/libraries/LibDiamondCut.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,66 +12,78 @@ import "./LibDiamondStorage.sol";
import "../interfaces/IDiamondCut.sol";

library LibDiamondCut {
event DiamondCut(IDiamondCut.Facet[] _diamondCut, address _init, bytes _calldata);
event DiamondCut(IDiamondCut.FacetCut[] _diamondCut, address _init, bytes _calldata);

// Non-standard internal function version of diamondCut
// Internal function version of diamondCut
// This code is almost the same as the external diamondCut,
// except it is using 'Facet[] memory _diamondCut' instead of
// 'Facet[] calldata _diamondCut'.
// The code is duplicated to prevent copying calldata to memory which
// causes an error for a two dimensional array.
function diamondCut(IDiamondCut.Facet[] memory _diamondCut) internal {
function diamondCut(
IDiamondCut.FacetCut[] memory _diamondCut,
address _init,
bytes memory _calldata
) internal {
require(_diamondCut.length > 0, "LibDiamondCut: No facets to cut");
LibDiamondStorage.DiamondStorage storage ds = LibDiamondStorage.diamondStorage();
uint256 selectorCount = ds.selectors.length;
for (uint256 facetIndex; facetIndex < _diamondCut.length; facetIndex++) {
selectorCount = addReplaceRemoveFacetSelectors(
selectorCount,
_diamondCut[facetIndex].facetAddress,
_diamondCut[facetIndex].action,
_diamondCut[facetIndex].functionSelectors
);
}
emit DiamondCut(_diamondCut, address(0), new bytes(0));
emit DiamondCut(_diamondCut, _init, _calldata);
initializeDiamondCut(_init, _calldata);
}

function addReplaceRemoveFacetSelectors(
uint256 _selectorCount,
address _newFacetAddress,
IDiamondCut.FacetCutAction _action,
bytes4[] memory _selectors
) internal returns (uint256) {
LibDiamondStorage.DiamondStorage storage ds = LibDiamondStorage.diamondStorage();
require(_selectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
// add or replace functions
if (_newFacetAddress != address(0)) {
hasContractCode(_newFacetAddress, "LibDiamondCut: facet has no code");
// add and replace selectors
for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
bytes4 selector = _selectors[selectorIndex];
address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
// add
if (oldFacetAddress == address(0)) {
if (_action == IDiamondCut.FacetCutAction.Add) {
require(oldFacetAddress == address(0), "LibDiamondCut: Can't add function that already exists");
ds.facetAddressAndSelectorPosition[selector] = LibDiamondStorage.FacetAddressAndSelectorPosition(
_newFacetAddress,
uint16(_selectorCount)
);
ds.selectors.push(selector);
_selectorCount++;
} else {
} else if (_action == IDiamondCut.FacetCutAction.Replace) {
// replace
// only useful if immutable functions exist
require(oldFacetAddress != address(this), "LibDiamondCut: Can't replace immutable function");
if (oldFacetAddress != _newFacetAddress) {
// replace old facet address
ds.facetAddressAndSelectorPosition[selector].facetAddress = _newFacetAddress;
}
require(oldFacetAddress != _newFacetAddress, "LibDiamondCut: Can't replace function with same function");
require(oldFacetAddress != address(0), "LibDiamondCut: Can't replace function that doesn't exist");
// replace old facet address
ds.facetAddressAndSelectorPosition[selector].facetAddress = _newFacetAddress;
} else {
revert("LibDiamondCut: Incorrect FacetCutAction");
}
}
} else {
require(_action == IDiamondCut.FacetCutAction.Remove, "LibDiamondCut: action not set to FacetCutAction.Remove");
// remove functions
for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
bytes4 selector = _selectors[selectorIndex];
LibDiamondStorage.FacetAddressAndSelectorPosition memory oldFacetAddressAndSelectorPosition = ds
.facetAddressAndSelectorPosition[selector];
// if selector already does not exist then do nothing
if (oldFacetAddressAndSelectorPosition.facetAddress == address(0)) {
continue;
}
require(oldFacetAddressAndSelectorPosition.facetAddress != address(0), "LibDiamondCut: Can't remove function that doesn't exist");
// only useful if immutable functions exist
require(oldFacetAddressAndSelectorPosition.facetAddress != address(this), "LibDiamondCut: Can't remove immutable function.");
bytes4 lastSelector = ds.selectors[_selectorCount - 1];
// replace selector with last selector
Expand All @@ -88,6 +100,26 @@ library LibDiamondCut {
return _selectorCount;
}

function initializeDiamondCut(address _init, bytes memory _calldata) internal {
if (_init == address(0)) {
require(_calldata.length == 0, "LibDiamondCut: _init is address(0) but_calldata is not empty");
} else {
require(_calldata.length > 0, "LibDiamondCut: _calldata is empty but _init is not address(0)");
if (_init != address(this)) {
LibDiamondCut.hasContractCode(_init, "LibDiamondCut: _init address has no code");
}
(bool success, bytes memory error) = _init.delegatecall(_calldata);
if (!success) {
if (error.length > 0) {
// bubble up the error
revert(string(error));
} else {
revert("LibDiamondCut: _init function reverted");
}
}
}
}

function hasContractCode(address _contract, string memory _errorMessage) internal view {
uint256 contractSize;
assembly {
Expand Down
3 changes: 3 additions & 0 deletions migrations/1_initial_migration.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
/* eslint-disable prefer-const */
/* global artifacts */

const Migrations = artifacts.require('Migrations')

module.exports = function (deployer) {
Expand Down
38 changes: 35 additions & 3 deletions migrations/2_diamond.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,43 @@
/* eslint-disable prefer-const */
/* global artifacts */

const Diamond = artifacts.require('Diamond')
const DiamondCutFacet = artifacts.require('DiamondCutFacet')
const DiamondLoupeFacet = artifacts.require('DiamondLoupeFacet')
const OwnershipFacet = artifacts.require('OwnershipFacet')
const Test1Facet = artifacts.require('Test1Facet')
const Test2Facet = artifacts.require('Test2Facet')

const FacetCutAction = {
Add: 0,
Replace: 1,
Remove: 2
}

function getSelectors (contract) {
const selectors = contract.abi.reduce((acc, val) => {
if (val.type === 'function') {
acc.push(val.signature)
return acc
} else {
return acc
}
}, [])
return selectors
}

module.exports = function (deployer, network, accounts) {
// deployment steps
// The constructor inside Diamond deploys DiamondFacet
deployer.deploy(Diamond, accounts[0])
deployer.deploy(Test1Facet)
deployer.deploy(Test2Facet)

deployer.deploy(DiamondCutFacet)
deployer.deploy(DiamondLoupeFacet)
deployer.deploy(OwnershipFacet).then(() => {
const diamondCut = [
[DiamondCutFacet.address, FacetCutAction.Add, getSelectors(DiamondCutFacet)],
[DiamondLoupeFacet.address, FacetCutAction.Add, getSelectors(DiamondLoupeFacet)],
[OwnershipFacet.address, FacetCutAction.Add, getSelectors(OwnershipFacet)]
]
return deployer.deploy(Diamond, accounts[0], diamondCut)
})
}
10 changes: 8 additions & 2 deletions test/cacheBugTest.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ const DiamondCutFacet = artifacts.require('DiamondCutFacet')
const DiamondLoupeFacet = artifacts.require('DiamondLoupeFacet')
const Test1Facet = artifacts.require('Test1Facet')

const FacetCutAction = {
Add: 0,
Replace: 1,
Remove: 2
}

// The diamond example comes with 8 function selectors
// [cut, loupe, loupe, loupe, loupe, erc165, transferOwnership, owner]
// This bug manifests if you delete something from the final
Expand Down Expand Up @@ -53,7 +59,7 @@ contract('Cache bug test', async accounts => {
web3.eth.defaultAccount = accounts[0]

// Add functions
await diamondCutFacet.methods.diamondCut([[test1Facet.address, selectors]], zeroAddress, '0x').send({ from: web3.eth.defaultAccount, gas: 1000000 })
await diamondCutFacet.methods.diamondCut([[test1Facet.address, FacetCutAction.Add, selectors]], zeroAddress, '0x').send({ from: web3.eth.defaultAccount, gas: 1000000 })

// Remove function selectors
// Function selector for the owner function in slot 0
Expand All @@ -62,7 +68,7 @@ contract('Cache bug test', async accounts => {
sel5,
sel10
]
await diamondCutFacet.methods.diamondCut([[zeroAddress, selectors]], zeroAddress, '0x').send({ from: web3.eth.defaultAccount, gas: 1000000 })
await diamondCutFacet.methods.diamondCut([[zeroAddress, FacetCutAction.Remove, selectors]], zeroAddress, '0x').send({ from: web3.eth.defaultAccount, gas: 1000000 })
})

it('should not exhibit the cache bug', async () => {
Expand Down
Loading

0 comments on commit a98a98d

Please sign in to comment.