Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ntt: pack endpoint info in a single slot + replay protect messages #3709

Merged
merged 1 commit into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 192 additions & 34 deletions ethereum/contracts/native_token_transfer/EndpointManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,53 @@ contract EndpointManager is

uint64 sequence;
uint8 threshold;
mapping(address => bool) public isEndpoint;
address[] endpoints;

// ========================= ENDPOINT REGISTRATION =========================

// @dev Information about registered endpoints.
struct EndpointInfo {
// whether this endpoint is registered
bool registered;
// whether this endpoint is enabled
bool enabled;
uint8 index;
}

// @dev Information about registered endpoints.
// This is the source of truth, we define a couple of derived fields below
// for efficiency.
mapping(address => EndpointInfo) public endpointInfos;

// @dev List of enabled endpoints.
// invariant: forall (a: address), endpointInfos[a].enabled <=> a in enabledEndpoints
address[] enabledEndpoints;

// invariant: forall (i: uint8), enabledEndpointBitmap & i == 1 <=> endpointInfos[i].enabled
uint64 enabledEndpointBitmap;

uint8 constant MAX_ENDPOINTS = 64;

// @dev Total number of registered endpoints. This number can only increase.
// invariant: numRegisteredEndpoints <= MAX_ENDPOINTS
// invariant: forall (i: uint8),
// i < numRegisteredEndpoints <=> exists (a: address), endpointInfos[a].index == i
uint8 numRegisteredEndpoints;

// =========================================================================

// @dev Information about attestations for a given message.
struct AttestationInfo {
// bitmap of endpoints that have attested to this message (NOTE: might contain disabled endpoints)
uint64 attestedEndpoints;
// whether this message has been executed
bool executed;
}

// Maps are keyed by hash of EndpointManagerMessage.
mapping(bytes32 => mapping(address => bool))
public managerMessageAttestations;
mapping(bytes32 => uint8) public managerMessageAttestationCounts;
mapping(bytes32 => AttestationInfo) public managerMessageAttestations;

modifier onlyEndpoint() {
if (!isEndpoint[msg.sender]) {
if (!endpointInfos[msg.sender].enabled) {
revert CallerNotEndpoint(msg.sender);
}
_;
Expand All @@ -54,6 +91,7 @@ contract EndpointManager is
isLockingMode = _isLockingMode;
chainId = _chainId;
evmChainId = _evmChainId;
_checkEndpointsInvariants();
}

/// @notice Called by the user to send the token cross-chain.
Expand All @@ -66,9 +104,9 @@ contract EndpointManager is
) external payable nonReentrant returns (uint64 msgSequence) {
// check up front that msg.value will cover the delivery price
uint256 totalPriceQuote = 0;
uint256[] memory endpointQuotes = new uint256[](endpoints.length);
for (uint256 i = 0; i < endpoints.length; i++) {
uint256 endpointPriceQuote = IEndpoint(endpoints[i])
uint256[] memory endpointQuotes = new uint256[](enabledEndpoints.length);
for (uint256 i = 0; i < enabledEndpoints.length; i++) {
uint256 endpointPriceQuote = IEndpoint(enabledEndpoints[i])
.quoteDeliveryPrice(recipientChain);
endpointQuotes[i] = endpointPriceQuote;
totalPriceQuote += endpointPriceQuote;
Expand Down Expand Up @@ -144,8 +182,8 @@ contract EndpointManager is
);

// call into endpoint contracts to send the message
for (uint256 i = 0; i < endpoints.length; i++) {
IEndpoint(endpoints[i]).sendMessage{value: endpointQuotes[i]}(
for (uint256 i = 0; i < enabledEndpoints.length; i++) {
IEndpoint(enabledEndpoints[i]).sendMessage{value: endpointQuotes[i]}(
recipientChain,
encodedManagerPayload
);
Expand Down Expand Up @@ -175,6 +213,20 @@ contract EndpointManager is
return amount;
}

// @dev Mark a message as executed.
// This function will revert if the message has already been executed.
function _markMessageExecuted(
bytes32 digest
) internal {
// check if this message has already been executed
if (managerMessageAttestations[digest].executed) {
revert MessageAlreadyExecuted(digest);
}

// mark this message as executed
managerMessageAttestations[digest].executed = true;
}

/// @notice Called by a Endpoint contract to deliver a verified attestation.
/// This function will decode the payload as an EndpointManagerMessage to extract the sequence, msgType, and other parameters.
/// When the threshold is reached for a sequence, this function will execute logic to handle the action specified by the msgType and payload.
Expand All @@ -186,25 +238,24 @@ contract EndpointManager is

bytes32 managerMessageHash = computeManagerMessageHash(payload);

// if the attestation for this sender has already been received, revert
if (
managerMessageAttestations[managerMessageHash][msg.sender] == true
) {
revert MessageAttestationAlreadyReceived(managerMessageHash, msg.sender);
}

// add the Endpoint attestation for the sequence number
managerMessageAttestations[managerMessageHash][msg.sender] = true;
// set the attested flag for this endpoint.
// TODO: this allows an endpoint to attest to a message multiple times.
// This is fine, because attestation is idempotent (bitwise or 1), but
// maybe we want to revert anyway?
// TODO: factor out the bitmap logic into helper functions (or even a library)
managerMessageAttestations[managerMessageHash].attestedEndpoints |=
uint64(1 << endpointInfos[msg.sender].index);

// increment the attestations for the sequence
managerMessageAttestationCounts[managerMessageHash]++;
uint8 attestationCount = messageAttestations(managerMessageHash);

// end early if the threshold hasn't been met.
// otherwise, continue with execution for the message type.
if (managerMessageAttestationCounts[managerMessageHash] < threshold) {
if (attestationCount < threshold) {
return;
}

_markMessageExecuted(managerMessageHash);

// parse the payload as an EndpointManagerMessage
EndpointManagerMessage memory message = parseEndpointManagerMessage(
payload
Expand Down Expand Up @@ -259,7 +310,7 @@ contract EndpointManager is

/// @notice Returns the Endpoint contracts that have been registered via governance.
function getEndpoints() external view returns (address[] memory) {
return endpoints;
return enabledEndpoints;
}

function nextSequence() public view returns (uint64) {
Expand All @@ -277,41 +328,77 @@ contract EndpointManager is

function setThreshold(uint8 newThreshold) external onlyOwner {
threshold = newThreshold;
_checkEndpointsInvariants();
}

function setEndpoint(address endpoint) external onlyOwner {
if (endpoint == address(0)) {
revert InvalidEndpointZeroAddress();
}

if (isEndpoint[endpoint]) {
revert AlreadyRegisteredEndpoint(endpoint);
if (endpointInfos[endpoint].registered) {
endpointInfos[endpoint].enabled = true;
} else {
endpointInfos[endpoint] = EndpointInfo({
registered: true,
enabled: true,
index: numRegisteredEndpoints
});
numRegisteredEndpoints++;
}
isEndpoint[endpoint] = true;
endpoints.push(endpoint);

enabledEndpoints.push(endpoint);

uint64 updatedEnabledEndpointBitmap
= enabledEndpointBitmap | uint64(1 << endpointInfos[endpoint].index);
// ensure that this actually changed the bitmap
assert(updatedEnabledEndpointBitmap > enabledEndpointBitmap);
enabledEndpointBitmap = updatedEnabledEndpointBitmap;

emit EndpointAdded(endpoint);

_checkEndpointsInvariants();
}

function removeEndpoint(address endpoint) external onlyOwner {
if (endpoint == address(0)) {
revert InvalidEndpointZeroAddress();
}

if (!isEndpoint[endpoint]) {
if (!endpointInfos[endpoint].registered) {
revert NonRegisteredEndpoint(endpoint);
}

delete isEndpoint[endpoint];
if (!endpointInfos[endpoint].enabled) {
revert DisabledEndpoint(endpoint);
}

endpointInfos[endpoint].enabled = false;

for (uint256 i = 0; i < endpoints.length; i++) {
if (endpoints[i] == endpoint) {
endpoints[i] = endpoints[endpoints.length - 1];
endpoints.pop();
uint64 updatedEnabledEndpointBitmap
= enabledEndpointBitmap & uint64(~(1 << endpointInfos[endpoint].index));
// ensure that this actually changed the bitmap
assert(updatedEnabledEndpointBitmap < enabledEndpointBitmap);
enabledEndpointBitmap = updatedEnabledEndpointBitmap;

bool removed = false;

for (uint256 i = 0; i < enabledEndpoints.length; i++) {
if (enabledEndpoints[i] == endpoint) {
enabledEndpoints[i] = enabledEndpoints[enabledEndpoints.length - 1];
enabledEndpoints.pop();
removed = true;
break;
}
}
assert(removed);

emit EndpointRemoved(endpoint);

_checkEndpointsInvariants();
// we call the invariant check on the endpoint here as well, since
// the above check only iterates through the enabled endpoints.
_checkEndpointInvariants(endpoint);
}

function encodeEndpointManagerMessage(
Expand Down Expand Up @@ -392,4 +479,75 @@ contract EndpointManager is
) public pure returns (bytes32) {
return keccak256(payload);
}

// @dev Count the number of attestations from enabled endpoints for a given message.
function messageAttestations(
bytes32 managerMessageHash
) public view returns (uint8 count) {
uint64 attestedEndpoints = managerMessageAttestations[managerMessageHash].attestedEndpoints;

return countSetBits(attestedEndpoints & enabledEndpointBitmap);
}

// @dev Count the number of set bits in a uint64
function countSetBits(uint64 x) public pure returns (uint8 count) {
while (x != 0) {
x &= x - 1;
count++;
}

return count;
}

// @dev Check that the endpoint manager is in a valid state.
// Checking these invariants is somewhat costly, but we only need to do it
// when modifying the endpoints, which happens infrequently.
function _checkEndpointsInvariants() internal view {
// TODO: add custom errors for each invariant

for (uint256 i = 0; i < enabledEndpoints.length; i++) {
_checkEndpointInvariants(enabledEndpoints[i]);
}

// invariant: each endpoint is only enabled once
for (uint256 i = 0; i < enabledEndpoints.length; i++) {
for (uint256 j = i + 1; j < enabledEndpoints.length; j++) {
assert(enabledEndpoints[i] != enabledEndpoints[j]);
}
}

// invariant: numRegisteredEndpoints <= MAX_ENDPOINTS
assert(numRegisteredEndpoints <= MAX_ENDPOINTS);

// invariant: threshold <= enabledEndpoints.length
require(threshold <= enabledEndpoints.length, "threshold <= enabledEndpoints.length");
}

// @dev Check that the endpoint is in a valid state.
function _checkEndpointInvariants(address endpoint) internal view {
EndpointInfo memory endpointInfo = endpointInfos[endpoint];

// if an endpoint is not registered, it should not be enabled
assert(endpointInfo.registered || (!endpointInfo.enabled && endpointInfo.index == 0));

bool endpointInEnabledBitmap = (enabledEndpointBitmap & uint64(1 << endpointInfo.index)) != 0;
bool endpointEnabled = endpointInfo.enabled;

bool endpointInEnabledEndpoints = false;

for (uint256 i = 0; i < enabledEndpoints.length; i++) {
if (enabledEndpoints[i] == endpoint) {
endpointInEnabledEndpoints = true;
break;
}
}

// invariant: endpointInfos[endpoint].enabled <=> enabledEndpointBitmap & (1 << endpointInfos[endpoint].index) != 0
assert(endpointInEnabledBitmap == endpointEnabled);

// invariant: endpointInfos[endpoint].enabled <=> endpoint in enabledEndpoints
assert(endpointInEnabledEndpoints == endpointEnabled);

assert(endpointInfo.index < numRegisteredEndpoints);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ interface IEndpointManager {
uint256 providedPayment
);
error MessageAttestationAlreadyReceived(bytes32 msgHash, address endpoint);
error MessageAlreadyExecuted(bytes32 msgHash);
error UnexpectedEndpointManagerMessageType(uint8 msgType);
error InvalidTargetChain(uint16 targetChain, uint16 thisChain);
error InvalidEndpointZeroAddress();
error AlreadyRegisteredEndpoint(address endpoint);
error NonRegisteredEndpoint(address endpoint);
error DisabledEndpoint(address endpoint);
error TooManyEndpoints();
error InvalidFork(uint256 evmChainId, uint256 blockChainId);

event EndpointAdded(address endpoint);
Expand Down
14 changes: 14 additions & 0 deletions ethereum/forge-test/native-token-transfer/EndpointManager.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,21 @@ contract TestEndpointManager is Test {
Wormhole wormhole;
EndpointManager endpointManager;

function test_countSetBits() public {
assertEq(endpointManager.countSetBits(5), 2);
assertEq(endpointManager.countSetBits(0), 0);
assertEq(endpointManager.countSetBits(15), 4);
assertEq(endpointManager.countSetBits(16), 1);
assertEq(endpointManager.countSetBits(65535), 16);
}

function setUp() public {
endpointManager = new EndpointManagerContract(
address(0),
false,
0,
0
);
// deploy sample token contract
// deploy wormhole contract
// wormhole = deployWormholeForTest();
Expand Down