Skip to content

Commit

Permalink
Merge pull request #45 from propeller-heads/zz/protosim-py/handle-vyp…
Browse files Browse the repository at this point in the history
…er-contracts

feat(simulation-py): make token bruteforce compatible with vyper
  • Loading branch information
zizou0x authored Nov 7, 2024
2 parents 19c5012 + 6aedd01 commit 43552b9
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 57 deletions.
21 changes: 14 additions & 7 deletions tycho_simulation_py/python/tycho_simulation_py/evm/pool_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from ..exceptions import RecoverableSimulationException
from ..models import EVMBlock, Capability, Address, EthereumToken
from .utils import (
ContractCompiler,
ERC20Slots,
create_engine,
get_contract_bytecode,
frac_to_decimal,
Expand Down Expand Up @@ -97,12 +99,14 @@ def __init__(
self.involved_contracts: set[Address] = involved_contracts or set()
"""A set of all contract addresses involved in the simulation of this pool."""

self.token_storage_slots: dict[Address, tuple[int, int]] = (
token_storage_slots or {}
self.token_storage_slots: dict[Address, tuple[ERC20Slots, ContractCompiler]] = (
token_storage_slots or {}
)
"""Allows the specification of custom storage slots for token allowances and
balances. This is particularly useful for token contracts involved in protocol
logic that extends beyond simple transfer functionality.
Each entry also specify the compiler with which the target contract was compiled.
This is later used to compute storage slot for maps.
"""

self._engine: Optional[SimulationEngine] = None
Expand Down Expand Up @@ -177,7 +181,7 @@ def _set_marginal_prices(self):
t1,
[sell_amount],
block=self.block,
overwrites=self.block_lasting_overwrites,
overwrites=self._get_overwrites(t0,t1),
)[0]
if Capability.ScaledPrices in self.capabilities:
self.marginal_prices[(t0, t1)] = frac_to_decimal(frac)
Expand Down Expand Up @@ -298,9 +302,11 @@ def _get_token_overwrites(
max_amount = sell_token.to_onchain_amount(
self.get_sell_amount_limit(sell_token, buy_token)
)
slots, compiler = self.token_storage_slots.get(sell_token.address, (ERC20Slots(0, 1), ContractCompiler.Solidity))
overwrites = ERC20OverwriteFactory(
sell_token,
token_slots=self.token_storage_slots.get(sell_token.address, (0, 1)),
token_slots=slots,
compiler=compiler
)
overwrites.set_balance(max_amount, EXTERNAL_ACCOUNT)
overwrites.set_allowance(
Expand All @@ -317,10 +323,11 @@ def _get_balance_overwrites(self) -> dict[Address, dict[int, int]]:
balance_overwrites = {}
address = self.balance_owner or self.id_
for t in self.tokens:
slots = (0, 1)
slots = ERC20Slots(0, 1)
compiler = ContractCompiler.Solidity
if t.address in self.involved_contracts:
slots = self.token_storage_slots.get(t.address)
overwrites = ERC20OverwriteFactory(t, token_slots=slots)
slots, compiler = self.token_storage_slots.get(t.address)
overwrites = ERC20OverwriteFactory(t, token_slots=slots, compiler=compiler)
overwrites.set_balance(
t.to_onchain_amount(self.balances[t.address]), address
)
Expand Down
87 changes: 46 additions & 41 deletions tycho_simulation_py/python/tycho_simulation_py/evm/token.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .adapter_contract import TychoSimulationContract
from .utils import ERC20OverwriteFactory
from .utils import ContractCompiler, ERC20OverwriteFactory, ERC20Slots
from .constants import EXTERNAL_ACCOUNT
from . import SimulationEngine
from ..models import EVMBlock, EthereumToken
Expand All @@ -14,7 +14,7 @@ class SlotDetectionFailure(Exception):

def brute_force_slots(
t: EthereumToken, block: EVMBlock, engine: SimulationEngine
) -> tuple[int, int]:
) -> tuple[ERC20Slots, ContractCompiler]:
"""Brute-force detection of storage slots for token allowances and balances.
This function attempts to determine the storage slots used by the token contract for
Expand All @@ -37,9 +37,9 @@ def brute_force_slots(
Returns
-------
tuple[int, int]
A tuple containing the detected balance storage slot and the allowance
storage slot, respectively.
tuple[tuple[int, int], ContractCompiler]
A tuple containing a tuple containing the detected balance storage slot and the allowance
storage slot, respectively and in what compiler was used for this contract.
Raises
------
Expand All @@ -49,47 +49,52 @@ def brute_force_slots(
"""
token_contract = TychoSimulationContract(t.address, "ERC20", engine)
balance_slot = None
for i in range(20):
overwrite_factory = ERC20OverwriteFactory(t, (i, 1))
overwrite_factory.set_balance(_MARKER_VALUE, EXTERNAL_ACCOUNT)
res = token_contract.call(
"balanceOf",
[EXTERNAL_ACCOUNT],
block_number=block.id,
timestamp=int(block.ts.timestamp()),
overrides=overwrite_factory.get_tycho_overwrites(),
caller=EXTERNAL_ACCOUNT,
value=0,
)
if res.return_value is None:
continue
if res.return_value[0] == _MARKER_VALUE:
balance_slot = i
break
compiler = ContractCompiler.Solidity
for i in range(100):
for compiler_flag in [ContractCompiler.Solidity, ContractCompiler.Vyper]:
overwrite_factory = ERC20OverwriteFactory(t, ERC20Slots(i, 1), compiler=compiler_flag)
overwrite_factory.set_balance(_MARKER_VALUE, EXTERNAL_ACCOUNT)
res = token_contract.call(
"balanceOf",
[EXTERNAL_ACCOUNT],
block_number=block.id,
timestamp=int(block.ts.timestamp()),
overrides=overwrite_factory.get_tycho_overwrites(),
caller=EXTERNAL_ACCOUNT,
value=0,
)

allowance_slot = None
for i in range(20):
overwrite_factory = ERC20OverwriteFactory(t, (0, i))
overwrite_factory.set_allowance(_MARKER_VALUE, _SPENDER, EXTERNAL_ACCOUNT)
res = token_contract.call(
"allowance",
[EXTERNAL_ACCOUNT, _SPENDER],
block_number=block.id,
timestamp=int(block.ts.timestamp()),
overrides=overwrite_factory.get_tycho_overwrites(),
caller=EXTERNAL_ACCOUNT,
value=0,
)
if res.return_value is None:
continue
if res.return_value[0] == _MARKER_VALUE:
allowance_slot = i
break
if res.return_value is None:
continue
if res.return_value[0] == _MARKER_VALUE:
balance_slot = i
compiler = compiler_flag
break

if balance_slot is None:
raise SlotDetectionFailure(f"Failed to infer balance slot for {t.address}")

allowance_slot = None
for i in range(100):
overwrite_factory = ERC20OverwriteFactory(t, ERC20Slots(0, i), compiler=compiler)
overwrite_factory.set_allowance(_MARKER_VALUE, _SPENDER, EXTERNAL_ACCOUNT)
res = token_contract.call(
"allowance",
[EXTERNAL_ACCOUNT, _SPENDER],
block_number=block.id,
timestamp=int(block.ts.timestamp()),
overrides=overwrite_factory.get_tycho_overwrites(),
caller=EXTERNAL_ACCOUNT,
value=0,
)
if res.return_value is None:
continue
if res.return_value[0] == _MARKER_VALUE:
allowance_slot = i
break


if allowance_slot is None:
raise SlotDetectionFailure(f"Failed to infer allowance slot for {t.address}")

return balance_slot, allowance_slot
return (ERC20Slots(balance_slot, allowance_slot), compiler)
41 changes: 32 additions & 9 deletions tycho_simulation_py/python/tycho_simulation_py/evm/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import enum
import json
import os
from decimal import Decimal
from fractions import Fraction
from functools import lru_cache
from logging import getLogger
from typing import Final, Any
from typing import Final, Any, NamedTuple

import eth_abi
import eth_utils
Expand Down Expand Up @@ -61,9 +62,26 @@ def create_engine(

return engine

class ContractCompiler(enum.Enum):
Solidity = enum.auto()
Vyper = enum.auto()

def compute_map_slot(self, map_base_slot: bytes, key: bytes) -> bytes:
if self == ContractCompiler.Solidity:
return eth_utils.keccak(key + map_base_slot)
elif self == ContractCompiler.Vyper:
return eth_utils.keccak(map_base_slot + key)
else:
raise NotImplementedError(f"compute_map_slot not implemented for {self.name}")


class ERC20Slots(NamedTuple):
balance_map: int
allowance_map: int


class ERC20OverwriteFactory:
def __init__(self, token: EthereumToken, token_slots=(0, 1)):
def __init__(self, token: EthereumToken, token_slots: ERC20Slots = ERC20Slots(0, 1), compiler: ContractCompiler = ContractCompiler.Solidity):
"""
Initialize the ERC20OverwriteFactory.
Expand All @@ -72,8 +90,9 @@ def __init__(self, token: EthereumToken, token_slots=(0, 1)):
"""
self._token = token
self._overwrites = dict()
self._balance_slot: int = token_slots[0]
self._allowance_slot: int = token_slots[1]
self._contract_compiler = compiler
self._balance_slot: int = token_slots.balance_map
self._allowance_slot: int = token_slots.allowance_map
self._total_supply_slot: Final[int] = 2

def set_balance(self, balance: int, owner: Address):
Expand All @@ -84,7 +103,7 @@ def set_balance(self, balance: int, owner: Address):
balance: The balance value.
owner: The owner's address.
"""
storage_index = get_storage_slot_at_key(HexStr(owner), self._balance_slot)
storage_index = get_storage_slot_at_key(HexStr(owner), self._balance_slot, self._contract_compiler)
self._overwrites[storage_index] = balance
log.log(
5,
Expand All @@ -103,8 +122,8 @@ def set_allowance(self, allowance: int, spender: Address, owner: Address):
"""
storage_index = get_storage_slot_at_key(
HexStr(spender),
get_storage_slot_at_key(HexStr(owner), self._allowance_slot),
)
get_storage_slot_at_key(HexStr(owner), self._allowance_slot, self._contract_compiler),
self._contract_compiler)
self._overwrites[storage_index] = allowance
log.log(
5,
Expand Down Expand Up @@ -153,7 +172,7 @@ def get_geth_overwrites(self) -> dict[Address, dict[int, int]]:
return {self._token.address: {"stateDiff": formatted_overwrites, "code": code}}


def get_storage_slot_at_key(key: Address, mapping_slot: int) -> int:
def get_storage_slot_at_key(key: Address, mapping_slot: int, compiler = ContractCompiler.Solidity) -> int:
"""Get storage slot index of a value stored at a certain key in a mapping
Parameters
Expand All @@ -164,6 +183,10 @@ def get_storage_slot_at_key(key: Address, mapping_slot: int) -> int:
mapping_slot
Storage slot at which the mapping itself is stored. See the examples for more
explanation.
compiler
The compiler with which the target contract was compiled. Solidity and Vyper handle
maps differently. This defaults to Solidity because it's the most used.
Returns
-------
Expand Down Expand Up @@ -193,7 +216,7 @@ def get_storage_slot_at_key(key: Address, mapping_slot: int) -> int:
"""
key_bytes = bytes.fromhex(key[2:]).rjust(32, b"\0")
mapping_slot_bytes = int.to_bytes(mapping_slot, 32, "big")
slot_bytes = eth_utils.keccak(key_bytes + mapping_slot_bytes)
slot_bytes = compiler.compute_map_slot(mapping_slot_bytes, key_bytes)
return int.from_bytes(slot_bytes, "big")


Expand Down

0 comments on commit 43552b9

Please sign in to comment.