Skip to content

Commit

Permalink
feat: Dynamically detect token storage slots when necessary
Browse files Browse the repository at this point in the history
Certain protocols include token contracts in their swap logic. Previously, this could cause balance mocks to fail by attempting to initialize an account that was already initialized with mock contract code, resulting in a silent no-op.

With this update, such cases are detected by inspecting the `involved_contracts` attribute. If a token's address is found in the involved contracts, the PoolState class will now dynamically identify the correct storage slots for balances and addresses. These slots will then be used for mocking balances in future operations, ensuring proper handling of such scenarios.
  • Loading branch information
kayibal committed Oct 17, 2024
1 parent 232b966 commit d2dbc26
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 10 deletions.
1 change: 1 addition & 0 deletions protosim_py/python/protosim_py/evm/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def decode_pool_state(
adapter_contract_path=self.adapter_contract,
trace=self.trace,
manual_updates=manual_updates,
involved_contracts=set(component.contract_ids),
**optional_attributes,
)

Expand Down
35 changes: 29 additions & 6 deletions protosim_py/python/protosim_py/evm/pool_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from eth_utils import keccak
from eth_typing import HexStr

from . import token
from . import SimulationEngine, AccountInfo, SimulationParameters
from .adapter_contract import AdapterContract
from .constants import MAX_BALANCE, EXTERNAL_ACCOUNT
Expand Down Expand Up @@ -46,6 +47,8 @@ def __init__(
block_lasting_overwrites: defaultdict[Address, dict[int, int]] = None,
manual_updates: bool = False,
trace: bool = False,
involved_contracts=None,
token_storage_slots=None,
):
self.id_ = id_
"""The pools identifier."""
Expand Down Expand Up @@ -91,10 +94,22 @@ def __init__(
self.trace: bool = trace
"""If set, vm will emit detailed traces about the execution."""

self.involved_contracts: set[Address] = involved_contracts or set()
"""A set of all contract addresses involved in the simulation of this pool."""

self.token_storage_slots: dict[Address, tuple[int, int]] = (
token_storage_slots or {}
)
"""Allows the specification of custom storage slots for token allowances and
balances. This is particularly useful for token contracts involved in protocol
logic that extends beyond simple transfer functionality.
"""

self._engine: Optional[SimulationEngine] = None
self._set_engine()
self._adapter_contract = AdapterContract(ADAPTER_ADDRESS, self._engine)
self._set_capabilities()
self._init_token_storage_slots()
if len(self.marginal_prices) == 0:
self._set_marginal_prices()

Expand All @@ -105,11 +120,6 @@ def _set_engine(self):
The engine will have the specified adapter contract mocked, as well as the
tokens used by the pool.
Parameters
----------
engine
Optional simulation engine instance.
"""
if self._engine is not None:
return
Expand Down Expand Up @@ -194,6 +204,16 @@ def _set_capabilities(self):
f"Pool {self.id_} hash different capabilities depending on the token pair!"
)

def _init_token_storage_slots(self):
for t in self.tokens:
if (
t.address in self.involved_contracts
and t.address not in self.token_storage_slots
):
self.token_storage_slots[t.address] = token.brute_force_slots(
t, self.block, self._engine
)

def get_amount_out(
self: TPoolState,
sell_token: EthereumToken,
Expand Down Expand Up @@ -291,8 +311,11 @@ def _get_token_overwrites(
def _get_balance_overwrites(self) -> dict[Address, dict[int, int]]:
balance_overwrites = {}
address = self.balance_owner or self.id_
slots = (0, 1)
for t in self.tokens:
overwrites = ERC20OverwriteFactory(t)
if t.address in self.involved_contracts:
slots = self.token_storage_slots.get(t.address)
overwrites = ERC20OverwriteFactory(t, token_slots=slots)
overwrites.set_balance(
t.to_onchain_amount(self.balances[t.address]), address
)
Expand Down
96 changes: 96 additions & 0 deletions protosim_py/python/protosim_py/evm/token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from .adapter_contract import ProtoSimContract
from .utils import ERC20OverwriteFactory
from .constants import EXTERNAL_ACCOUNT
from . import SimulationEngine
from ..models import EVMBlock, EthereumToken

_MARKER_VALUE = 314159265358979323846264338327950288419716939937510
_SPENDER = "0x08d967bb0134F2d07f7cfb6E246680c53927DD30"

class SlotDetectionFailure(Exception):
pass

def brute_force_slots(
t: EthereumToken, block: EVMBlock, engine: SimulationEngine
) -> tuple[int, int]:
"""Brute-force detection of storage slots for token allowances and balances.
This function attempts to determine the storage slots used by the token contract for
balance and allowance values by systematically testing different storage locations.
It uses EVM simulation to overwrite storage slots (from 0 to 19) and checks whether
the overwritten slot produces the expected result by making VM calls to
`balanceOf(...)` or `allowance(...)`.
The token contract and its storage must already be set up within the engine's
database before calling this function.
Parameters
----------
t : EthereumToken
The token whose storage slots are being brute-forced.
block : EVMBlock
The block at which the simulation is executed.
engine : SimulationEngine
The engine used to simulate the blockchain environment.
Returns
-------
tuple[int, int]
A tuple containing the detected balance storage slot and the allowance
storage slot, respectively.
Raises
------
SlotDetectionFailure
If the function fails to detect a valid slot for either balances or allowances
after checking all possible slots (0-19).
"""
token_contract = ProtoSimContract(t.address, "ERC20", engine)
balance_slot = None
for i in range(20):
overwrite_factory = ERC20OverwriteFactory(t, (i, 1))
overwrite_factory.set_balance(_MARKER_VALUE, EXTERNAL_ACCOUNT)
res = token_contract.call(
"balanceOf",
[EXTERNAL_ACCOUNT],
block_number=block.id,
timestamp=int(block.ts.timestamp()),
overrides=overwrite_factory.get_protosim_overwrites(),
caller=EXTERNAL_ACCOUNT,
value=0,
)
if res.return_value is None:
continue
if res.return_value[0] == _MARKER_VALUE:
balance_slot = i
break

allowance_slot = None
for i in range(20):
overwrite_factory = ERC20OverwriteFactory(t, (0, i))
overwrite_factory.set_allowance(_MARKER_VALUE, _SPENDER, EXTERNAL_ACCOUNT)
res = token_contract.call(
"allowance",
[EXTERNAL_ACCOUNT, _SPENDER],
block_number=block.id,
timestamp=int(block.ts.timestamp()),
overrides=overwrite_factory.get_protosim_overwrites(),
caller=EXTERNAL_ACCOUNT,
value=0,
)
if res.return_value is None:
continue
if res.return_value[0] == _MARKER_VALUE:
allowance_slot = i
break

if balance_slot is None:
raise SlotDetectionFailure(f"Failed to infer balance slot for {t.address}")

if allowance_slot is None:
raise SlotDetectionFailure(f"Failed to infer allowance slot for {t.address}")

return balance_slot, allowance_slot



8 changes: 4 additions & 4 deletions protosim_py/python/protosim_py/evm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def create_engine(


class ERC20OverwriteFactory:
def __init__(self, token: EthereumToken):
def __init__(self, token: EthereumToken, token_slots = (0, 1)):
"""
Initialize the ERC20OverwriteFactory.
Expand All @@ -72,8 +72,8 @@ def __init__(self, token: EthereumToken):
"""
self._token = token
self._overwrites = dict()
self._balance_slot: Final[int] = 0
self._allowance_slot: Final[int] = 1
self._balance_slot: int = token_slots[0]
self._allowance_slot: int = token_slots[1]
self._total_supply_slot: Final[int] = 2

def set_balance(self, balance: int, owner: Address):
Expand Down Expand Up @@ -377,4 +377,4 @@ def parse_account_info(accounts: list[dict[str, Any]]) -> list[AccountUpdate]:
)
)

return parsed
return parsed
Empty file.
124 changes: 124 additions & 0 deletions protosim_py/python/test/evm/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from protosim_py.evm import AccountInfo, StateUpdate, BlockHeader, SimulationEngine
from protosim_py.evm.constants import MAX_BALANCE
from protosim_py.evm.utils import exec_rpc_method, get_code_for_address
from protosim_py.models import Address, EVMBlock


def read_account_storage_from_rpc(
address: Address, block_hash: str, connection_string: str = None
) -> dict[str, str]:
"""Reads complete storage of a contract from a Geth instance.
Parameters
----------
address:
The contracts address
block_hash:
The block hash at which we want to retrieve storage at.
connection_string:
The connection string for the Geth rpc endpoint.
Returns
-------
storage:
A dictionary containing the hex encoded slots (both keys and values).
"""

res = exec_rpc_method(
connection_string,
"debug_storageRangeAt",
[block_hash, 0, address, "0x00", 0x7FFFFFFF],
)

storage = {}
for i in res["storage"].values():
try:
if i["key"] is None:
raise RuntimeError(
"Node with preimages required, found a slot without key!"
)
k = i["key"]
if i["value"] is None:
continue
else:
v = i["value"]
storage[k] = v
except (TypeError, ValueError):
raise RuntimeError(
"Encountered invalid storage data retrieved data from geth -> " + str(i)
)
return storage


def init_contract_via_rpc(
block: EVMBlock,
contract_address: Address,
engine: SimulationEngine,
connection_string: str,
):
"""Initializes a contract in the simulation engine using data fetched via RPC.
This function retrieves the contract's bytecode and storage from an external RPC
endpoint and uses it to initialize the contract within the simulation engine.
Additionally, it sets up necessary default accounts and updates the contract's
state based on the provided block.
Parameters
----------
block :
The block at which to initialize the contract.
contract_address :
The address of the contract to be initialized.
engine :
The simulation engine instance where the contract is set up.
connection_string :
RPC connection string used to fetch contract data.
Returns
-------
SimulationEngine
The simulation engine with the contract initialized.
"""
bytecode = get_code_for_address(contract_address, connection_string)
storage = read_account_storage_from_rpc(
contract_address, block.hash_, connection_string
)
engine.init_account(
address="0x0000000000000000000000000000000000000000",
account=AccountInfo(balance=0, nonce=0),
mocked=False,
permanent_storage=None,
)
engine.init_account(
address="0x0000000000000000000000000000000000000004",
account=AccountInfo(balance=0, nonce=0),
mocked=False,
permanent_storage=None,
)
engine.init_account(
address=contract_address,
account=AccountInfo(
balance=MAX_BALANCE,
nonce=0,
code=bytecode,
),
mocked=False,
permanent_storage=None,
)
engine.update_state(
{
contract_address: StateUpdate(
storage={
int.from_bytes(
bytes.fromhex(k[2:]), "big", signed=False
): int.from_bytes(bytes.fromhex(v[2:]), "big", signed=False)
for k, v in storage.items()
},
balance=0,
)
},
BlockHeader(
number=block.id, hash=block.hash_, timestamp=int(block.ts.timestamp())
),
)
return engine
33 changes: 33 additions & 0 deletions protosim_py/python/test/test_evm_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os

import pytest

from protosim_py.evm.storage import TychoDBSingleton
from protosim_py.evm.token import brute_force_slots
from protosim_py.evm.utils import (
create_engine,
)
from test.evm.utils import init_contract_via_rpc
from protosim_py.models import EthereumToken, EVMBlock

_ETH_RPC_URL = os.getenv("ETH_RPC_URL")


@pytest.mark.skipif(
_ETH_RPC_URL is None,
reason="Geth RPC access required. Please via `ETH_RPC_URL` env variable.",
)
def test_brute_force_slots():
block = EVMBlock(
20984206, "0x01a709ad31a9ff223f7932ae8f6d6762e02b114250393adf128a2858b39c4b9d"
)
token_address = "0xac3E018457B222d93114458476f3E3416Abbe38F"
token = EthereumToken("sFRAX", token_address, 18)
TychoDBSingleton.initialize()
engine = create_engine([], trace=True)
engine = init_contract_via_rpc(block, token_address, engine, _ETH_RPC_URL)

balance_slots, allowance_slot = brute_force_slots(token, block, engine)

assert balance_slots == 3
assert allowance_slot == 4

0 comments on commit d2dbc26

Please sign in to comment.