From 312b96d05a292e686d5c2e66263acef34df95859 Mon Sep 17 00:00:00 2001 From: dbeal Date: Tue, 10 Sep 2024 08:24:14 +0900 Subject: [PATCH] changes to improve performance of erc7412 required from upstream (#62) * changes to improve performance of erc7412 required from upstream * update error handling --- src/synthetix/perps/perps.py | 1 + src/synthetix/utils/multicall.py | 60 ++++++++++++---- src/tests/test_oracles.py | 116 +++++++++++++++++++++++++++++++ 3 files changed, 164 insertions(+), 13 deletions(-) create mode 100644 src/tests/test_oracles.py diff --git a/src/synthetix/perps/perps.py b/src/synthetix/perps/perps.py index a5fa4de..989053d 100644 --- a/src/synthetix/perps/perps.py +++ b/src/synthetix/perps/perps.py @@ -237,6 +237,7 @@ def _prepare_oracle_call(self, market_names: [str] = []): self.snx, self.snx.contracts["pyth_erc7412_wrapper"]["PythERC7412Wrapper"]["address"], price_update_data, + 0, args, ) value = len(market_names) diff --git a/src/synthetix/utils/multicall.py b/src/synthetix/utils/multicall.py index 4a211dd..7019395 100644 --- a/src/synthetix/utils/multicall.py +++ b/src/synthetix/utils/multicall.py @@ -7,7 +7,9 @@ # constants -ORACLE_DATA_REQUIRED = "0xcf2cabdf" +SELECTOR_ORACLE_DATA_REQUIRED = "0xcf2cabdf" +SELECTOR_ORACLE_DATA_REQUIRED_WITH_FEE = "0x0e7186fb" +SELECTOR_ERRORS = "0x0b42fd17" def decode_result(contract, function_name, result): @@ -20,14 +22,33 @@ def decode_result(contract, function_name, result): # ERC-7412 support -def decode_erc7412_error(snx, error): +def decode_erc7412_errors_error(error): + """Decodes an Errors error""" + error_data = decode_hex(f"0x{error[10:]}") + + errors = decode(["bytes[]"], error_data)[0] + errors = [ContractCustomError(data=encode_hex(e)) for e in errors] + errors.reverse() + + return errors + + +def decode_erc7412_oracle_data_required_error(snx, error): """Decodes an OracleDataRequired error""" # remove the signature and decode the error data error_data = decode_hex(f"0x{error[10:]}") # decode the result - output_types = ["address", "bytes"] - address, data = decode(output_types, error_data) + # could be one of two types with different args + output_types = ["address", "bytes", "uint256"] + try: + address, data, fee = decode(output_types, error_data) + print("USED NORMAL output types") + except: + print("USING BACKUP output types") + address, data = decode(output_types[:2], error_data) + fee = 0 + address = snx.web3.to_checksum_address(address) # decode the bytes data into the arguments for the oracle @@ -41,7 +62,7 @@ def decode_erc7412_error(snx, error): ) feed_ids = [encode_hex(raw_feed_id) for raw_feed_id in raw_feed_ids] - return address, feed_ids, (update_type, staleness_tolerance, raw_feed_ids) + return address, feed_ids, fee, (update_type, staleness_tolerance, raw_feed_ids) except: pass @@ -51,14 +72,14 @@ def decode_erc7412_error(snx, error): feed_ids = [encode_hex(raw_feed_id)] raw_feed_ids = [raw_feed_id] - return address, feed_ids, (update_type, publish_time, raw_feed_ids) + return address, feed_ids, fee, (update_type, publish_time, raw_feed_ids) except: pass raise Exception("Error data can not be decoded") -def make_fulfillment_request(snx, address, price_update_data, args): +def make_fulfillment_request(snx, address, price_update_data, fee, args): erc_contract = snx.web3.eth.contract( address=address, abi=snx.contracts["pyth_erc7412_wrapper"]["PythERC7412Wrapper"]["abi"], @@ -71,7 +92,7 @@ def make_fulfillment_request(snx, address, price_update_data, args): ) # assume 1 wei per price update - value = len(price_update_data) * 1 + value = fee if fee > 0 else len(price_update_data) * 1 update_tx = erc_contract.functions.fulfillOracleQuery( encoded_args @@ -80,11 +101,24 @@ def make_fulfillment_request(snx, address, price_update_data, args): def handle_erc7412_error(snx, error, calls): - if type(error) is ContractCustomError and error.data.startswith( - ORACLE_DATA_REQUIRED + "When receiving a ERC7412 error, will return an updated list of calls with the required price updates" + if type(error) is ContractCustomError and error.data.startswith(SELECTOR_ERRORS): + errors = decode_erc7412_errors_error(error.data) + + # TODO: execute in parallel + for sub_error in errors: + sub_calls = handle_erc7412_error(snx, sub_error, []) + calls = sub_calls + calls + + return calls + if type(error) is ContractCustomError and ( + error.data.startswith(SELECTOR_ORACLE_DATA_REQUIRED) + or error.data.startswith(SELECTOR_ORACLE_DATA_REQUIRED_WITH_FEE) ): # decode error data - address, feed_ids, args = decode_erc7412_error(snx, error.data) + address, feed_ids, fee, args = decode_erc7412_oracle_data_required_error( + snx, error.data + ) update_type = args[0] if update_type == 1: @@ -106,7 +140,7 @@ def handle_erc7412_error(snx, error, calls): # create a new request to, data, value = make_fulfillment_request( - snx, address, price_update_data, args + snx, address, price_update_data, fee, args ) elif update_type == 2: # fetch the data from pyth for those feed ids @@ -115,7 +149,7 @@ def handle_erc7412_error(snx, error, calls): # create a new request to, data, value = make_fulfillment_request( - snx, address, price_update_data, args + snx, address, price_update_data, fee, args ) else: snx.logger.error(f"Unknown update type: {update_type}") diff --git a/src/tests/test_oracles.py b/src/tests/test_oracles.py new file mode 100644 index 0000000..9664429 --- /dev/null +++ b/src/tests/test_oracles.py @@ -0,0 +1,116 @@ +import os +from web3.exceptions import ContractCustomError +from eth_abi import decode, encode +from eth_utils import encode_hex, decode_hex +from synthetix import Synthetix +from synthetix.utils.multicall import ( + handle_erc7412_error, + SELECTOR_ERRORS, + SELECTOR_ORACLE_DATA_REQUIRED, + SELECTOR_ORACLE_DATA_REQUIRED_WITH_FEE, +) + +# constants +ODR_ERROR_TYPES = ["address", "bytes"] +ODR_FEE_ERROR_TYPES = ["address", "bytes", "uint256"] + +ODR_BYTES_TYPES = { + 1: ["uint8", "uint64", "bytes32[]"], + 2: ["uint8", "uint64", "bytes32"], +} + + +# encode some errors +def encode_odr_error(snx, inputs, with_fee=False): + "Utility to help encode errors to test" + types = ODR_FEE_ERROR_TYPES if with_fee else ODR_ERROR_TYPES + fee = [1] if with_fee else [] + address = snx.contracts["pyth_erc7412_wrapper"]["PythERC7412Wrapper"]["address"] + + # get the update type + update_type = inputs[0] + bytes_types = ODR_BYTES_TYPES[update_type] + + # encode bytes + error_bytes = encode(bytes_types, inputs) + + # encode the error + error = encode(types, [address, error_bytes] + fee) + return error + + +# tests + + +def test_update_type_1_with_staleness(snx): + # Test update_type 1 with staleness 3600 + feed_id = snx.pyth.price_feed_ids["ETH"] + error = encode_odr_error(snx, (1, 3600, [decode_hex(feed_id)])) + error_hex = SELECTOR_ORACLE_DATA_REQUIRED + error.hex() + + custom_error = ContractCustomError(message="Test error", data=error_hex) + calls = handle_erc7412_error(snx, custom_error, []) + + assert len(calls) == 1, "Expected 1 call for update_type 1" + assert calls[0][1] == True, "Expected call to be marked as static" + assert calls[0][2] > 0, "Expected non-zero value for the call" + + +def test_update_type_2_with_recent_publish_time(snx): + # Test update_type 2 with a publish_time in the last 60 seconds + feed_id = snx.pyth.price_feed_ids["BTC"] + current_time = snx.web3.eth.get_block("latest").timestamp + recent_publish_time = current_time - 30 # 30 seconds ago + + error = encode_odr_error(snx, (2, recent_publish_time, decode_hex(feed_id))) + error_hex = SELECTOR_ORACLE_DATA_REQUIRED + error.hex() + + custom_error = ContractCustomError(message="Test error", data=error_hex) + calls = handle_erc7412_error(snx, custom_error, []) + + assert len(calls) == 1, "Expected 1 call for update_type 2" + assert calls[0][1] == True, "Expected call to be marked as static" + assert calls[0][2] > 0, "Expected non-zero value for the call" + + +def test_oracle_data_required_with_fee(snx): + # Test OracleDataRequired error with fee + feed_id = snx.pyth.price_feed_ids["ETH"] + error = encode_odr_error(snx, (1, 3600, [decode_hex(feed_id)]), with_fee=True) + error_hex = SELECTOR_ORACLE_DATA_REQUIRED_WITH_FEE + error.hex() + + custom_error = ContractCustomError(message="Test error", data=error_hex) + calls = handle_erc7412_error(snx, custom_error, []) + + assert len(calls) == 1, "Expected 1 call for OracleDataRequired with fee" + assert calls[0][1] == True, "Expected call to be marked as static" + assert calls[0][2] > 0, "Expected non-zero value for the call" + + +def test_errors_with_multiple_sub_errors(snx): + # Test Errors error which includes multiple individual errors + feed_id_1 = snx.pyth.price_feed_ids["ETH"] + feed_id_2 = snx.pyth.price_feed_ids["BTC"] + + error_1 = encode_odr_error(snx, (1, 3600, [decode_hex(feed_id_1)])) + error_1_hex = SELECTOR_ORACLE_DATA_REQUIRED + error_1.hex() + + error_2 = encode_odr_error( + snx, (2, snx.web3.eth.get_block("latest").timestamp - 30, decode_hex(feed_id_2)) + ) + error_2_hex = SELECTOR_ORACLE_DATA_REQUIRED + error_2.hex() + + # Encode multiple errors + errors_data = encode( + ["bytes[]"], [(decode_hex(error_1_hex), decode_hex(error_2_hex))] + ) + + errors_hex = SELECTOR_ERRORS + errors_data.hex() + + custom_error = ContractCustomError(message="Test error", data=errors_hex) + calls = handle_erc7412_error(snx, custom_error, []) + + assert len(calls) == 2, "Expected 2 calls for Errors with 2 sub-errors" + for call in calls: + assert call[1] == True, "Expected all calls to be marked as static" + assert call[2] > 0, "Expected non-zero value for all calls"