Skip to content

Commit

Permalink
Improved add/replace/remove function code
Browse files Browse the repository at this point in the history
  • Loading branch information
mudgen committed Dec 19, 2020
1 parent b7b4ef9 commit 9e62923
Show file tree
Hide file tree
Showing 16 changed files with 833 additions and 631 deletions.
2 changes: 1 addition & 1 deletion contracts/Diamond.sol
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.7.1;
pragma solidity ^0.7.6;
pragma experimental ABIEncoderV2;

/******************************************************************************\
Expand Down
2 changes: 1 addition & 1 deletion contracts/Migrations.sol
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.7.0;
pragma solidity ^0.7.6;

contract Migrations {
address public owner;
Expand Down
16 changes: 3 additions & 13 deletions contracts/facets/DiamondCutFacet.sol
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.7.1;
pragma solidity ^0.7.6;
pragma experimental ABIEncoderV2;

/******************************************************************************\
Expand All @@ -10,7 +10,7 @@ pragma experimental ABIEncoderV2;
import "../interfaces/IDiamondCut.sol";
import "../libraries/LibDiamond.sol";

contract DiamondCutFacet is IDiamondCut {
contract DiamondCutFacet is IDiamondCut {
/// @notice Add/replace/remove any number of functions and optionally execute
/// a function with delegatecall
/// @param _diamondCut Contains the facet addresses and function selectors
Expand All @@ -23,16 +23,6 @@ contract DiamondCutFacet is IDiamondCut {
bytes calldata _calldata
) external override {
LibDiamond.enforceIsContractOwner();
uint256 selectorCount = LibDiamond.diamondStorage().selectors.length;
for (uint256 facetIndex; facetIndex < _diamondCut.length; facetIndex++) {
selectorCount = LibDiamond.addReplaceRemoveFacetSelectors(
selectorCount,
_diamondCut[facetIndex].facetAddress,
_diamondCut[facetIndex].action,
_diamondCut[facetIndex].functionSelectors
);
}
emit DiamondCut(_diamondCut, _init, _calldata);
LibDiamond.initializeDiamondCut(_init, _calldata);
LibDiamond.diamondCut(_diamondCut, _init, _calldata);
}
}
2 changes: 1 addition & 1 deletion contracts/facets/DiamondLoupeFacet.sol
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.7.1;
pragma solidity ^0.7.6;
pragma experimental ABIEncoderV2;

/******************************************************************************\
Expand Down
2 changes: 1 addition & 1 deletion contracts/facets/OwnershipFacet.sol
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.7.1;
pragma solidity ^0.7.6;

import "../libraries/LibDiamond.sol";
import "../interfaces/IERC173.sol";
Expand Down
2 changes: 1 addition & 1 deletion contracts/facets/Test1Facet.sol
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.7.0;
pragma solidity ^0.7.6;

contract Test1Facet {
event TestEvent(address something);
Expand Down
2 changes: 1 addition & 1 deletion contracts/facets/Test2Facet.sol
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.7.0;
pragma solidity ^0.7.6;

contract Test2Facet {
function test2Func1() external {}
Expand Down
2 changes: 1 addition & 1 deletion contracts/interfaces/IDiamondCut.sol
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.7.1;
pragma solidity ^0.7.6;
pragma experimental ABIEncoderV2;

/******************************************************************************\
Expand Down
2 changes: 1 addition & 1 deletion contracts/interfaces/IDiamondLoupe.sol
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.7.1;
pragma solidity ^0.7.6;
pragma experimental ABIEncoderV2;

/******************************************************************************\
Expand Down
2 changes: 1 addition & 1 deletion contracts/interfaces/IERC165.sol
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.7.1;
pragma solidity ^0.7.6;
pragma experimental ABIEncoderV2;

interface IERC165 {
Expand Down
2 changes: 1 addition & 1 deletion contracts/interfaces/IERC173.sol
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.7.1;
pragma solidity ^0.7.6;

/// @title ERC-173 Contract Ownership Standard
/// Note: the ERC-165 identifier for this interface is 0x7f5828d0
Expand Down
144 changes: 67 additions & 77 deletions contracts/libraries/LibDiamond.sol
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.7.1;
pragma solidity ^0.7.6;
pragma experimental ABIEncoderV2;

/******************************************************************************\
Expand All @@ -16,10 +16,10 @@ library LibDiamond {
uint16 selectorPosition;
}

struct DiamondStorage {
struct DiamondStorage {
// function selector => facet address and selector position in selectors array
mapping(bytes4 => FacetAddressAndSelectorPosition) facetAddressAndSelectorPosition;
bytes4[] selectors;
bytes4[] selectors;
mapping(bytes4 => bool) supportedInterfaces;
// owner of the contract
address contractOwner;
Expand All @@ -45,99 +45,89 @@ library LibDiamond {
contractOwner_ = diamondStorage().contractOwner;
}

function enforceIsContractOwner() view internal {
function enforceIsContractOwner() internal view {
require(msg.sender == diamondStorage().contractOwner, "LibDiamond: Must be contract owner");
}

modifier onlyOwner {
require(msg.sender == diamondStorage().contractOwner, "LibDiamond: Must be contract owner");
_;
}

event DiamondCut(IDiamondCut.FacetCut[] _diamondCut, address _init, bytes _calldata);

// Internal function version of diamondCut
// This code is almost the same as the external diamondCut,
// except it is using 'Facet[] memory _diamondCut' instead of
// 'Facet[] calldata _diamondCut'.
// The code is duplicated to prevent copying calldata to memory which
// causes an error for a two dimensional array.
function diamondCut(
IDiamondCut.FacetCut[] memory _diamondCut,
address _init,
bytes memory _calldata
) internal {
uint256 selectorCount = diamondStorage().selectors.length;
) internal {
for (uint256 facetIndex; facetIndex < _diamondCut.length; facetIndex++) {
selectorCount = addReplaceRemoveFacetSelectors(
selectorCount,
_diamondCut[facetIndex].facetAddress,
_diamondCut[facetIndex].action,
_diamondCut[facetIndex].functionSelectors
);
IDiamondCut.FacetCutAction action = _diamondCut[facetIndex].action;
if (action == IDiamondCut.FacetCutAction.Add) {
addFunctions(_diamondCut[facetIndex].facetAddress, _diamondCut[facetIndex].functionSelectors);
} else if (action == IDiamondCut.FacetCutAction.Replace) {
replaceFunctions(_diamondCut[facetIndex].facetAddress, _diamondCut[facetIndex].functionSelectors);
} else if (action == IDiamondCut.FacetCutAction.Remove) {
removeFunctions(_diamondCut[facetIndex].facetAddress, _diamondCut[facetIndex].functionSelectors);
} else {
revert("LibDiamondCut: Incorrect FacetCutAction");
}
}
emit DiamondCut(_diamondCut, _init, _calldata);
initializeDiamondCut(_init, _calldata);
}

function addReplaceRemoveFacetSelectors(
uint256 _selectorCount,
address _newFacetAddress,
IDiamondCut.FacetCutAction _action,
bytes4[] memory _selectors
) internal returns (uint256) {
function addFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
DiamondStorage storage ds = diamondStorage();
require(_selectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
if (_action == IDiamondCut.FacetCutAction.Add) {
require(_newFacetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)");
enforceHasContractCode(_newFacetAddress, "LibDiamondCut: Add facet has no code");
for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
bytes4 selector = _selectors[selectorIndex];
address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
require(oldFacetAddress == address(0), "LibDiamondCut: Can't add function that already exists");
ds.facetAddressAndSelectorPosition[selector] = FacetAddressAndSelectorPosition(
_newFacetAddress,
uint16(_selectorCount)
);
ds.selectors.push(selector);
_selectorCount++;
}
} else if(_action == IDiamondCut.FacetCutAction.Replace) {
require(_newFacetAddress != address(0), "LibDiamondCut: Replace facet can't be address(0)");
enforceHasContractCode(_newFacetAddress, "LibDiamondCut: Replace facet has no code");
for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
bytes4 selector = _selectors[selectorIndex];
address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
// only useful if immutable functions exist
require(oldFacetAddress != address(this), "LibDiamondCut: Can't replace immutable function");
require(oldFacetAddress != _newFacetAddress, "LibDiamondCut: Can't replace function with same function");
require(oldFacetAddress != address(0), "LibDiamondCut: Can't replace function that doesn't exist");
// replace old facet address
ds.facetAddressAndSelectorPosition[selector].facetAddress = _newFacetAddress;
}
} else if(_action == IDiamondCut.FacetCutAction.Remove) {
require(_newFacetAddress == address(0), "LibDiamondCut: Remove facet address must be address(0)");
for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
bytes4 selector = _selectors[selectorIndex];
FacetAddressAndSelectorPosition memory oldFacetAddressAndSelectorPosition = ds.facetAddressAndSelectorPosition[selector];
require(oldFacetAddressAndSelectorPosition.facetAddress != address(0), "LibDiamondCut: Can't remove function that doesn't exist");
// only useful if immutable functions exist
require(oldFacetAddressAndSelectorPosition.facetAddress != address(this), "LibDiamondCut: Can't remove immutable function.");
// replace selector with last selector
if (oldFacetAddressAndSelectorPosition.selectorPosition != _selectorCount - 1) {
bytes4 lastSelector = ds.selectors[_selectorCount - 1];
ds.selectors[oldFacetAddressAndSelectorPosition.selectorPosition] = lastSelector;
ds.facetAddressAndSelectorPosition[lastSelector].selectorPosition = oldFacetAddressAndSelectorPosition.selectorPosition;
}
// delete last selector
ds.selectors.pop();
delete ds.facetAddressAndSelectorPosition[selector];
_selectorCount--;
uint16 selectorCount = uint16(ds.selectors.length);
require(_facetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)");
enforceHasContractCode(_facetAddress, "LibDiamondCut: Add facet has no code");
for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) {
bytes4 selector = _functionSelectors[selectorIndex];
address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
require(oldFacetAddress == address(0), "LibDiamondCut: Can't add function that already exists");
ds.facetAddressAndSelectorPosition[selector] = FacetAddressAndSelectorPosition(_facetAddress, selectorCount);
ds.selectors.push(selector);
selectorCount++;
}
}

function replaceFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
DiamondStorage storage ds = diamondStorage();
require(_facetAddress != address(0), "LibDiamondCut: Replace facet can't be address(0)");
enforceHasContractCode(_facetAddress, "LibDiamondCut: Replace facet has no code");
for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) {
bytes4 selector = _functionSelectors[selectorIndex];
address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
// can't replace immutable functions -- functions defined directly in the diamond
require(oldFacetAddress != address(this), "LibDiamondCut: Can't replace immutable function");
require(oldFacetAddress != _facetAddress, "LibDiamondCut: Can't replace function with same function");
require(oldFacetAddress != address(0), "LibDiamondCut: Can't replace function that doesn't exist");
// replace old facet address
ds.facetAddressAndSelectorPosition[selector].facetAddress = _facetAddress;
}
}

function removeFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
DiamondStorage storage ds = diamondStorage();
uint256 selectorCount = ds.selectors.length;
require(_facetAddress == address(0), "LibDiamondCut: Remove facet address must be address(0)");
for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) {
bytes4 selector = _functionSelectors[selectorIndex];
FacetAddressAndSelectorPosition memory oldFacetAddressAndSelectorPosition = ds.facetAddressAndSelectorPosition[selector];
require(oldFacetAddressAndSelectorPosition.facetAddress != address(0), "LibDiamondCut: Can't remove function that doesn't exist");
// can't remove immutable functions -- functions defined directly in the diamond
require(oldFacetAddressAndSelectorPosition.facetAddress != address(this), "LibDiamondCut: Can't remove immutable function.");
// replace selector with last selector
selectorCount--;
if (oldFacetAddressAndSelectorPosition.selectorPosition != selectorCount) {
bytes4 lastSelector = ds.selectors[selectorCount];
ds.selectors[oldFacetAddressAndSelectorPosition.selectorPosition] = lastSelector;
ds.facetAddressAndSelectorPosition[lastSelector].selectorPosition = oldFacetAddressAndSelectorPosition.selectorPosition;
}
} else {
revert("LibDiamondCut: Incorrect FacetCutAction");
// delete last selector
ds.selectors.pop();
delete ds.facetAddressAndSelectorPosition[selector];
}
return _selectorCount;
}

function initializeDiamondCut(address _init, bytes memory _calldata) internal {
Expand Down
Loading

0 comments on commit 9e62923

Please sign in to comment.