Skip to content

Commit

Permalink
changes to improve performance of erc7412 required from upstream (#62)
Browse files Browse the repository at this point in the history
* changes to improve performance of erc7412 required from upstream

* update error handling
  • Loading branch information
dbeal-eth authored Sep 9, 2024
1 parent e55819c commit 312b96d
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/synthetix/perps/perps.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def _prepare_oracle_call(self, market_names: [str] = []):
self.snx,
self.snx.contracts["pyth_erc7412_wrapper"]["PythERC7412Wrapper"]["address"],
price_update_data,
0,
args,
)
value = len(market_names)
Expand Down
60 changes: 47 additions & 13 deletions src/synthetix/utils/multicall.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@


# constants
ORACLE_DATA_REQUIRED = "0xcf2cabdf"
SELECTOR_ORACLE_DATA_REQUIRED = "0xcf2cabdf"
SELECTOR_ORACLE_DATA_REQUIRED_WITH_FEE = "0x0e7186fb"
SELECTOR_ERRORS = "0x0b42fd17"


def decode_result(contract, function_name, result):
Expand All @@ -20,14 +22,33 @@ def decode_result(contract, function_name, result):


# ERC-7412 support
def decode_erc7412_error(snx, error):
def decode_erc7412_errors_error(error):
"""Decodes an Errors error"""
error_data = decode_hex(f"0x{error[10:]}")

errors = decode(["bytes[]"], error_data)[0]
errors = [ContractCustomError(data=encode_hex(e)) for e in errors]
errors.reverse()

return errors


def decode_erc7412_oracle_data_required_error(snx, error):
"""Decodes an OracleDataRequired error"""
# remove the signature and decode the error data
error_data = decode_hex(f"0x{error[10:]}")

# decode the result
output_types = ["address", "bytes"]
address, data = decode(output_types, error_data)
# could be one of two types with different args
output_types = ["address", "bytes", "uint256"]
try:
address, data, fee = decode(output_types, error_data)
print("USED NORMAL output types")
except:
print("USING BACKUP output types")
address, data = decode(output_types[:2], error_data)
fee = 0

address = snx.web3.to_checksum_address(address)

# decode the bytes data into the arguments for the oracle
Expand All @@ -41,7 +62,7 @@ def decode_erc7412_error(snx, error):
)

feed_ids = [encode_hex(raw_feed_id) for raw_feed_id in raw_feed_ids]
return address, feed_ids, (update_type, staleness_tolerance, raw_feed_ids)
return address, feed_ids, fee, (update_type, staleness_tolerance, raw_feed_ids)
except:
pass

Expand All @@ -51,14 +72,14 @@ def decode_erc7412_error(snx, error):

feed_ids = [encode_hex(raw_feed_id)]
raw_feed_ids = [raw_feed_id]
return address, feed_ids, (update_type, publish_time, raw_feed_ids)
return address, feed_ids, fee, (update_type, publish_time, raw_feed_ids)
except:
pass

raise Exception("Error data can not be decoded")


def make_fulfillment_request(snx, address, price_update_data, args):
def make_fulfillment_request(snx, address, price_update_data, fee, args):
erc_contract = snx.web3.eth.contract(
address=address,
abi=snx.contracts["pyth_erc7412_wrapper"]["PythERC7412Wrapper"]["abi"],
Expand All @@ -71,7 +92,7 @@ def make_fulfillment_request(snx, address, price_update_data, args):
)

# assume 1 wei per price update
value = len(price_update_data) * 1
value = fee if fee > 0 else len(price_update_data) * 1

update_tx = erc_contract.functions.fulfillOracleQuery(
encoded_args
Expand All @@ -80,11 +101,24 @@ def make_fulfillment_request(snx, address, price_update_data, args):


def handle_erc7412_error(snx, error, calls):
if type(error) is ContractCustomError and error.data.startswith(
ORACLE_DATA_REQUIRED
"When receiving a ERC7412 error, will return an updated list of calls with the required price updates"
if type(error) is ContractCustomError and error.data.startswith(SELECTOR_ERRORS):
errors = decode_erc7412_errors_error(error.data)

# TODO: execute in parallel
for sub_error in errors:
sub_calls = handle_erc7412_error(snx, sub_error, [])
calls = sub_calls + calls

return calls
if type(error) is ContractCustomError and (
error.data.startswith(SELECTOR_ORACLE_DATA_REQUIRED)
or error.data.startswith(SELECTOR_ORACLE_DATA_REQUIRED_WITH_FEE)
):
# decode error data
address, feed_ids, args = decode_erc7412_error(snx, error.data)
address, feed_ids, fee, args = decode_erc7412_oracle_data_required_error(
snx, error.data
)
update_type = args[0]

if update_type == 1:
Expand All @@ -106,7 +140,7 @@ def handle_erc7412_error(snx, error, calls):

# create a new request
to, data, value = make_fulfillment_request(
snx, address, price_update_data, args
snx, address, price_update_data, fee, args
)
elif update_type == 2:
# fetch the data from pyth for those feed ids
Expand All @@ -115,7 +149,7 @@ def handle_erc7412_error(snx, error, calls):

# create a new request
to, data, value = make_fulfillment_request(
snx, address, price_update_data, args
snx, address, price_update_data, fee, args
)
else:
snx.logger.error(f"Unknown update type: {update_type}")
Expand Down
116 changes: 116 additions & 0 deletions src/tests/test_oracles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os
from web3.exceptions import ContractCustomError
from eth_abi import decode, encode
from eth_utils import encode_hex, decode_hex
from synthetix import Synthetix
from synthetix.utils.multicall import (
handle_erc7412_error,
SELECTOR_ERRORS,
SELECTOR_ORACLE_DATA_REQUIRED,
SELECTOR_ORACLE_DATA_REQUIRED_WITH_FEE,
)

# constants
ODR_ERROR_TYPES = ["address", "bytes"]
ODR_FEE_ERROR_TYPES = ["address", "bytes", "uint256"]

ODR_BYTES_TYPES = {
1: ["uint8", "uint64", "bytes32[]"],
2: ["uint8", "uint64", "bytes32"],
}


# encode some errors
def encode_odr_error(snx, inputs, with_fee=False):
"Utility to help encode errors to test"
types = ODR_FEE_ERROR_TYPES if with_fee else ODR_ERROR_TYPES
fee = [1] if with_fee else []
address = snx.contracts["pyth_erc7412_wrapper"]["PythERC7412Wrapper"]["address"]

# get the update type
update_type = inputs[0]
bytes_types = ODR_BYTES_TYPES[update_type]

# encode bytes
error_bytes = encode(bytes_types, inputs)

# encode the error
error = encode(types, [address, error_bytes] + fee)
return error


# tests


def test_update_type_1_with_staleness(snx):
# Test update_type 1 with staleness 3600
feed_id = snx.pyth.price_feed_ids["ETH"]
error = encode_odr_error(snx, (1, 3600, [decode_hex(feed_id)]))
error_hex = SELECTOR_ORACLE_DATA_REQUIRED + error.hex()

custom_error = ContractCustomError(message="Test error", data=error_hex)
calls = handle_erc7412_error(snx, custom_error, [])

assert len(calls) == 1, "Expected 1 call for update_type 1"
assert calls[0][1] == True, "Expected call to be marked as static"
assert calls[0][2] > 0, "Expected non-zero value for the call"


def test_update_type_2_with_recent_publish_time(snx):
# Test update_type 2 with a publish_time in the last 60 seconds
feed_id = snx.pyth.price_feed_ids["BTC"]
current_time = snx.web3.eth.get_block("latest").timestamp
recent_publish_time = current_time - 30 # 30 seconds ago

error = encode_odr_error(snx, (2, recent_publish_time, decode_hex(feed_id)))
error_hex = SELECTOR_ORACLE_DATA_REQUIRED + error.hex()

custom_error = ContractCustomError(message="Test error", data=error_hex)
calls = handle_erc7412_error(snx, custom_error, [])

assert len(calls) == 1, "Expected 1 call for update_type 2"
assert calls[0][1] == True, "Expected call to be marked as static"
assert calls[0][2] > 0, "Expected non-zero value for the call"


def test_oracle_data_required_with_fee(snx):
# Test OracleDataRequired error with fee
feed_id = snx.pyth.price_feed_ids["ETH"]
error = encode_odr_error(snx, (1, 3600, [decode_hex(feed_id)]), with_fee=True)
error_hex = SELECTOR_ORACLE_DATA_REQUIRED_WITH_FEE + error.hex()

custom_error = ContractCustomError(message="Test error", data=error_hex)
calls = handle_erc7412_error(snx, custom_error, [])

assert len(calls) == 1, "Expected 1 call for OracleDataRequired with fee"
assert calls[0][1] == True, "Expected call to be marked as static"
assert calls[0][2] > 0, "Expected non-zero value for the call"


def test_errors_with_multiple_sub_errors(snx):
# Test Errors error which includes multiple individual errors
feed_id_1 = snx.pyth.price_feed_ids["ETH"]
feed_id_2 = snx.pyth.price_feed_ids["BTC"]

error_1 = encode_odr_error(snx, (1, 3600, [decode_hex(feed_id_1)]))
error_1_hex = SELECTOR_ORACLE_DATA_REQUIRED + error_1.hex()

error_2 = encode_odr_error(
snx, (2, snx.web3.eth.get_block("latest").timestamp - 30, decode_hex(feed_id_2))
)
error_2_hex = SELECTOR_ORACLE_DATA_REQUIRED + error_2.hex()

# Encode multiple errors
errors_data = encode(
["bytes[]"], [(decode_hex(error_1_hex), decode_hex(error_2_hex))]
)

errors_hex = SELECTOR_ERRORS + errors_data.hex()

custom_error = ContractCustomError(message="Test error", data=errors_hex)
calls = handle_erc7412_error(snx, custom_error, [])

assert len(calls) == 2, "Expected 2 calls for Errors with 2 sub-errors"
for call in calls:
assert call[1] == True, "Expected all calls to be marked as static"
assert call[2] > 0, "Expected non-zero value for all calls"

0 comments on commit 312b96d

Please sign in to comment.