diff --git a/src/protocol/errors.rs b/src/protocol/errors.rs index 48ce6fe0..cbe87e0f 100644 --- a/src/protocol/errors.rs +++ b/src/protocol/errors.rs @@ -1,7 +1,7 @@ //! Protocol generic errors use thiserror::Error; -use super::models::GetAmountOutResult; +use super::{models::GetAmountOutResult, vm::errors::TychoSimulationError}; /// Enumeration of possible errors that can occur during a trade simulation. #[derive(Debug, PartialEq)] @@ -42,10 +42,18 @@ pub enum TransitionError { InvalidEventType(), } -#[derive(Debug, PartialEq, Error)] +#[derive(Debug, Error)] pub enum InvalidSnapshotError { #[error("Missing attributes {0}")] MissingAttribute(String), #[error("Value error {0}")] ValueError(String), + #[error("Unable to set up vm state on the engine: {0}")] + VMError(TychoSimulationError), +} + +impl From for InvalidSnapshotError { + fn from(error: TychoSimulationError) -> Self { + InvalidSnapshotError::VMError(error) + } } diff --git a/src/protocol/uniswap_v2/tycho_decoder.rs b/src/protocol/uniswap_v2/tycho_decoder.rs index a67d3788..b8ef6541 100644 --- a/src/protocol/uniswap_v2/tycho_decoder.rs +++ b/src/protocol/uniswap_v2/tycho_decoder.rs @@ -49,10 +49,6 @@ mod tests { .unwrap() .naive_utc(); //Sample timestamp - let mut static_attributes: HashMap = HashMap::new(); - static_attributes.insert("attr1".to_string(), "0x000012".into()); - static_attributes.insert("attr2".to_string(), "0x000005".into()); - ProtocolComponent { id: "State1".to_string(), protocol_system: "system1".to_string(), @@ -110,9 +106,10 @@ mod tests { let result = UniswapV2State::try_from(snapshot); assert!(result.is_err()); - assert_eq!( + + assert!(matches!( result.err().unwrap(), - InvalidSnapshotError::MissingAttribute("reserve1".to_string()) - ); + InvalidSnapshotError::MissingAttribute(attr) if attr == *"reserve1" + )); } } diff --git a/src/protocol/uniswap_v3/tycho_decoder.rs b/src/protocol/uniswap_v3/tycho_decoder.rs index 93675bf8..755c4fdf 100644 --- a/src/protocol/uniswap_v3/tycho_decoder.rs +++ b/src/protocol/uniswap_v3/tycho_decoder.rs @@ -269,10 +269,10 @@ mod tests { let result = UniswapV3State::try_from(snapshot); assert!(result.is_err()); - assert_eq!( + assert!(matches!( result.err().unwrap(), - InvalidSnapshotError::MissingAttribute(missing_attribute) - ); + InvalidSnapshotError::MissingAttribute(attr) if attr == missing_attribute + )); } #[test] @@ -295,10 +295,10 @@ mod tests { let result = UniswapV3State::try_from(snapshot); assert!(result.is_err()); - assert_eq!( + assert!(matches!( result.err().unwrap(), - InvalidSnapshotError::ValueError("Unsupported fee amount".to_string()) - ); + InvalidSnapshotError::ValueError(err) if err == *"Unsupported fee amount" + )); } #[test] diff --git a/src/protocol/vm/mod.rs b/src/protocol/vm/mod.rs index 36345caf..f06d088e 100644 --- a/src/protocol/vm/mod.rs +++ b/src/protocol/vm/mod.rs @@ -2,8 +2,9 @@ mod adapter_contract; mod constants; mod engine; mod erc20_overwrite_factory; -mod errors; +pub mod errors; mod models; -mod state; +pub mod state; +pub mod tycho_decoder; mod tycho_simulation_contract; pub mod utils; diff --git a/src/protocol/vm/state.rs b/src/protocol/vm/state.rs index 81233bc4..54dec4c4 100644 --- a/src/protocol/vm/state.rs +++ b/src/protocol/vm/state.rs @@ -1,30 +1,11 @@ // TODO: remove skip for clippy dead_code check #![allow(dead_code)] -use std::any::Any; -use tracing::warn; - -use crate::{ - evm::{ - simulation::{SimulationEngine, SimulationParameters}, - simulation_db::BlockHeader, - tycho_db::PreCachedDB, - }, - models::ERC20Token, - protocol::vm::{ - constants::{ADAPTER_ADDRESS, EXTERNAL_ACCOUNT, MAX_BALANCE}, - engine::{create_engine, SHARED_TYCHO_DB}, - errors::TychoSimulationError, - tycho_simulation_contract::TychoSimulationContract, - utils::{get_code_for_contract, get_contract_bytecode}, - }, +use std::{ + any::Any, + collections::{HashMap, HashSet}, }; -use crate::protocol::vm::{ - erc20_overwrite_factory::{ERC20OverwriteFactory, Overwrites}, - models::Capability, - utils::SlotId, -}; use chrono::Utc; use ethers::{ abi::{decode, ParamType}, @@ -34,31 +15,48 @@ use ethers::{ }; use itertools::Itertools; use revm::{ + precompile::{Address as rAddress, Bytes}, primitives::{ - alloy_primitives::Keccak256, keccak256, AccountInfo, Address as rAddress, Bytecode, Bytes, - B256, KECCAK_EMPTY, U256 as rU256, + alloy_primitives::Keccak256, keccak256, AccountInfo, Bytecode, B256, KECCAK_EMPTY, + U256 as rU256, }, DatabaseRef, }; -use std::collections::{HashMap, HashSet}; +use tracing::warn; + use tycho_core::dto::ProtocolStateDelta; -// Necessary for the init_account method to be in scope -#[allow(unused_imports)] -use crate::evm::engine_db_interface::EngineDatabaseInterface; -use crate::protocol::{ - errors::{TradeSimulationError, TransitionError}, - events::{EVMLogMeta, LogIndex}, - models::GetAmountOutResult, - state::{ProtocolEvent, ProtocolSim}, +use crate::{ + evm::{ + engine_db_interface::EngineDatabaseInterface, + simulation::{SimulationEngine, SimulationParameters}, + simulation_db::BlockHeader, + tycho_db::PreCachedDB, + }, + models::ERC20Token, + protocol::{ + errors::{TradeSimulationError, TransitionError}, + events::{EVMLogMeta, LogIndex}, + models::GetAmountOutResult, + state::{ProtocolEvent, ProtocolSim}, + vm::{ + constants::{ADAPTER_ADDRESS, EXTERNAL_ACCOUNT, MAX_BALANCE}, + engine::{create_engine, SHARED_TYCHO_DB}, + erc20_overwrite_factory::{ERC20OverwriteFactory, Overwrites}, + errors::TychoSimulationError, + models::Capability, + tycho_simulation_contract::TychoSimulationContract, + utils::{get_code_for_contract, get_contract_bytecode, SlotId}, + }, + }, }; #[derive(Clone, Debug)] pub struct VMPoolState { /// The pool's identifier pub id: String, - /// The pools tokens - pub tokens: Vec, + /// The pool's token's addresses + pub tokens: Vec, /// The current block, will be used to set vm context pub block: BlockHeader, /// The pools token balances @@ -85,6 +83,10 @@ pub struct VMPoolState { pub stateless_contracts: HashMap>>, /// If set, vm will emit detailed traces about the execution pub trace: bool, + /// Indicates if the protocol uses custom update rules and requires update + /// triggers to recalculate spot prices ect. Default is to update on all changes on + /// the pool. + pub manual_updates: bool, engine: Option>, /// The adapter contract. This is used to run simulations adapter_contract: Option>, @@ -94,16 +96,14 @@ impl VMPoolState { #[allow(clippy::too_many_arguments)] pub async fn new( id: String, - tokens: Vec, + tokens: Vec, block: BlockHeader, balances: HashMap, balance_owner: Option, adapter_contract_path: String, - capabilities: HashSet, - block_lasting_overwrites: HashMap, involved_contracts: HashSet, - token_storage_slots: HashMap, stateless_contracts: HashMap>>, + manual_updates: bool, trace: bool, ) -> Result { let mut state = VMPoolState { @@ -113,14 +113,15 @@ impl VMPoolState { balances, balance_owner, spot_prices: HashMap::new(), - capabilities, - block_lasting_overwrites, + capabilities: HashSet::new(), + block_lasting_overwrites: HashMap::new(), involved_contracts, - token_storage_slots, + token_storage_slots: HashMap::new(), stateless_contracts, trace, engine: None, adapter_contract: None, + manual_updates, }; state .set_engine(adapter_contract_path) @@ -145,7 +146,7 @@ impl VMPoolState { let token_addresses = self .tokens .iter() - .map(|token| to_checksum(&token.address, None)) + .map(|addr| to_checksum(addr, None)) .collect(); let engine: SimulationEngine<_> = create_engine(SHARED_TYCHO_DB.clone(), token_addresses, self.trace).await; @@ -308,7 +309,7 @@ impl VMPoolState { // Generate all permutations of tokens and retrieve capabilities for tokens_pair in self.tokens.iter().permutations(2) { // Manually unpack the inner vector - if let [t0, t1] = &tokens_pair[..] { + if let [t0, t1] = tokens_pair[..] { let caps = self .adapter_contract .clone() @@ -318,7 +319,7 @@ impl VMPoolState { .to_string(), ) })? - .get_capabilities(self.id.clone()[2..].to_string(), t0.address, t1.address) + .get_capabilities(self.id.clone()[2..].to_string(), *t0, *t1) .await?; capabilities.push(caps); } @@ -359,7 +360,7 @@ impl VMPoolState { .map(|p| [p[0], p[1]]) { let sell_amount_limit = self - .get_sell_amount_limit(vec![(*sell_token).clone(), (*buy_token).clone()]) + .get_sell_amount_limit(vec![(sell_token.address), (buy_token.address)]) .await?; let price_result = self .adapter_contract @@ -406,7 +407,7 @@ impl VMPoolState { /// is significant and determines the direction of the price query. async fn get_sell_amount_limit( &mut self, - tokens: Vec, + tokens: Vec, ) -> Result { let binding = self .adapter_contract @@ -419,8 +420,8 @@ impl VMPoolState { let limits = binding .get_limits( self.id.clone()[2..].to_string(), - tokens[0].address, - tokens[1].address, + tokens[0], + tokens[1], self.block.number, Some( self.get_overwrites( @@ -439,7 +440,7 @@ impl VMPoolState { pub async fn get_overwrites( &mut self, - tokens: Vec, + tokens: Vec, max_amount: U256, ) -> Result, TychoSimulationError> { let token_overwrites = self @@ -455,7 +456,7 @@ impl VMPoolState { async fn get_token_overwrites( &self, - tokens: Vec, + tokens: Vec, max_amount: U256, ) -> Result, TychoSimulationError> { let sell_token = &tokens[0].clone(); @@ -467,10 +468,10 @@ impl VMPoolState { res.push(self.get_balance_overwrites(tokens)?); } let mut overwrites = ERC20OverwriteFactory::new( - rAddress::from_slice(&sell_token.address.0), + rAddress::from_slice(&sell_token.0), *self .token_storage_slots - .get(&sell_token.address) + .get(sell_token) .unwrap_or(&(SlotId::from(0), SlotId::from(1))), ); @@ -496,7 +497,7 @@ impl VMPoolState { fn get_balance_overwrites( &self, - tokens: Vec, + tokens: Vec, ) -> Result, TychoSimulationError> { let mut balance_overwrites: HashMap = HashMap::new(); let address = match self.balance_owner { @@ -507,12 +508,9 @@ impl VMPoolState { }?; for token in &tokens { - let slots = if self - .involved_contracts - .contains(&token.address) - { + let slots = if self.involved_contracts.contains(token) { self.token_storage_slots - .get(&token.address) + .get(token) .cloned() .ok_or_else(|| { TychoSimulationError::EncodingError("Token storage slots not found".into()) @@ -521,10 +519,10 @@ impl VMPoolState { (SlotId::from(0), SlotId::from(1)) }; - let mut overwrites = ERC20OverwriteFactory::new(rAddress::from(token.address.0), slots); + let mut overwrites = ERC20OverwriteFactory::new(rAddress::from(token.0), slots); overwrites.set_balance( self.balances - .get(&token.address) + .get(token) .cloned() .unwrap_or_default(), address, @@ -611,11 +609,7 @@ impl ProtocolSim for VMPoolState { #[cfg(test)] mod tests { use super::*; - use crate::{ - evm::{simulation_db::BlockHeader, tycho_models::AccountUpdate}, - models::ERC20Token, - protocol::vm::models::Capability, - }; + use ethers::{ prelude::{H256, U256}, types::Address as EthAddress, @@ -628,6 +622,11 @@ mod tests { str::FromStr, }; + use crate::{ + evm::{simulation_db::BlockHeader, tycho_models::AccountUpdate}, + protocol::vm::models::Capability, + }; + async fn setup_db(asset_path: &Path) -> Result<(), Box> { let file = File::open(asset_path)?; let data: Value = serde_json::from_reader(file)?; @@ -676,20 +675,10 @@ mod tests { .await .expect("Failed to set up database"); - let dai = ERC20Token::new( - "0x6b175474e89094c44da98b954eedeac495271d0f", - 18, - "DAI", - U256::from(10_000), - ); - let bal = ERC20Token::new( - "0xba100000625a3754423978a60c9317c58a424e3d", - 18, - "BAL", - U256::from(10_000), - ); + let dai_addr = H160::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(); + let bal_addr = H160::from_str("0xba100000625a3754423978a60c9317c58a424e3d").unwrap(); - let tokens = vec![dai.clone(), bal.clone()]; + let tokens = vec![dai_addr, bal_addr]; let block = BlockHeader { number: 18485417, hash: H256::from_str( @@ -708,21 +697,16 @@ mod tests { block, HashMap::from([ ( - EthAddress::from(dai.address.0), + EthAddress::from(dai_addr.0), U256::from_dec_str("178754012737301807104").unwrap(), ), - ( - EthAddress::from(bal.address.0), - U256::from_dec_str("91082987763369885696").unwrap(), - ), + (EthAddress::from(bal_addr.0), U256::from_dec_str("91082987763369885696").unwrap()), ]), Some(EthAddress::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap()), "src/protocol/vm/assets/BalancerV2SwapAdapter.evm.runtime".to_string(), HashSet::new(), HashMap::new(), - HashSet::new(), - HashMap::new(), - HashMap::new(), + false, false, ) .await @@ -752,8 +736,8 @@ mod tests { .unwrap() .get_capabilities( pool_state.id[2..].to_string(), - pool_state.tokens[0].address, - pool_state.tokens[1].address, + pool_state.tokens[0], + pool_state.tokens[1], ) .await .unwrap(); @@ -777,17 +761,25 @@ mod tests { .is_err()); } + fn dai() -> ERC20Token { + ERC20Token::new("0x6b175474e89094c44da98b954eedeac495271d0f", 18, "DAI", U256::from(10_000)) + } + + fn bal() -> ERC20Token { + ERC20Token::new("0xba100000625a3754423978a60c9317c58a424e3d", 18, "BAL", U256::from(10_000)) + } + #[tokio::test] async fn test_get_sell_amount_limit() { let mut pool_state = setup_pool_state().await; let dai_limit = pool_state - .get_sell_amount_limit(vec![pool_state.tokens[0].clone(), pool_state.tokens[1].clone()]) + .get_sell_amount_limit(vec![dai().address, bal().address]) .await .unwrap(); assert_eq!(dai_limit, U256::from_dec_str("100279494253364362835").unwrap()); let bal_limit = pool_state - .get_sell_amount_limit(vec![pool_state.tokens[1].clone(), pool_state.tokens[0].clone()]) + .get_sell_amount_limit(vec![pool_state.tokens[1], pool_state.tokens[0]]) .await .unwrap(); assert_eq!(bal_limit, U256::from_dec_str("13997408640689987484").unwrap()); @@ -798,17 +790,17 @@ mod tests { let mut pool_state = setup_pool_state().await; pool_state - .set_spot_prices(pool_state.tokens.clone()) + .set_spot_prices(vec![bal(), dai()]) .await .unwrap(); let dai_bal_spot_price = pool_state .spot_prices - .get(&(pool_state.tokens[0].address, pool_state.tokens[1].address)) + .get(&(pool_state.tokens[0], pool_state.tokens[1])) .unwrap(); let bal_dai_spot_price = pool_state .spot_prices - .get(&(pool_state.tokens[1].address, pool_state.tokens[0].address)) + .get(&(pool_state.tokens[1], pool_state.tokens[0])) .unwrap(); assert_eq!(dai_bal_spot_price, &0.137_778_914_319_047_9); assert_eq!(bal_dai_spot_price, &7.071_503_245_428_246); diff --git a/src/protocol/vm/tycho_decoder.rs b/src/protocol/vm/tycho_decoder.rs new file mode 100644 index 00000000..f0360dc2 --- /dev/null +++ b/src/protocol/vm/tycho_decoder.rs @@ -0,0 +1,251 @@ +use std::{ + collections::HashMap, + time::{SystemTime, UNIX_EPOCH}, +}; + +use ethers::types::{H160, H256, U256}; + +use tycho_client::feed::{synchronizer::ComponentWithState, Header}; + +use crate::{ + evm::{simulation_db::BlockHeader, tycho_db::PreCachedDB}, + protocol::{errors::InvalidSnapshotError, vm::state::VMPoolState, BytesConvertible}, +}; + +#[allow(dead_code)] +trait TryFromWithBlock { + type Error; + async fn try_from_with_block(value: T, block: Header) -> Result + where + Self: Sized; +} + +impl From
for BlockHeader { + fn from(header: Header) -> Self { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs(); + BlockHeader { number: header.number, hash: H256::from_bytes(&header.hash), timestamp: now } + } +} + +impl TryFromWithBlock for VMPoolState { + type Error = InvalidSnapshotError; + + /// Decodes a `ComponentWithState` into a `VMPoolState`. + /// + /// Errors with a `InvalidSnapshotError`. + async fn try_from_with_block( + snapshot: ComponentWithState, + block: Header, + ) -> Result { + let id = snapshot.component.id.clone(); + let tokens = snapshot + .component + .tokens + .clone() + .into_iter() + .map(|t| H160::from_bytes(&t)) + .collect(); + let block = BlockHeader::from(block); + let balances = snapshot + .state + .balances + .iter() + .map(|(k, v)| (H160::from_bytes(k), U256::from_bytes(v))) + .collect(); + let balance_owner = snapshot + .state + .attributes + .get("balance_owner") + .map(H160::from_bytes); + + let manual_updates = snapshot + .component + .static_attributes + .contains_key("manual_updates"); + + // Decode involved contracts + let mut stateless_contracts = HashMap::new(); + let mut index = 0; + + loop { + let address_key = format!("stateless_contract_addr_{}", index); + if let Some(encoded_address_bytes) = snapshot + .state + .attributes + .get(&address_key) + { + let encoded_address = hex::encode(encoded_address_bytes); + // Stateless contracts address are UTF-8 encoded + let address_hex = encoded_address + .strip_prefix("0x") + .unwrap_or(&encoded_address); + + let decoded = match hex::decode(address_hex) { + Ok(decoded_bytes) => match String::from_utf8(decoded_bytes) { + Ok(decoded_string) => decoded_string, + Err(_) => continue, + }, + Err(_) => continue, + }; + + let code_key = format!("stateless_contract_code_{}", index); + let code = snapshot + .state + .attributes + .get(&code_key) + .map(|value| value.to_vec()); + + stateless_contracts.insert(decoded, code); + index += 1; + } else { + break; + } + } + + let involved_contracts = snapshot + .component + .contract_ids + .iter() + .map(H160::from_bytes) + .collect(); + + let adapter_file_path = format!( + "src/protocol/vm/assets/{}", + to_adapter_file_name( + snapshot + .component + .protocol_system + .as_str(), + ) + ); + + let pool_state = VMPoolState::new( + id, + tokens, + block, + balances, + balance_owner, + adapter_file_path, + involved_contracts, + stateless_contracts, + manual_updates, + false, + ) + .await + .map_err(InvalidSnapshotError::VMError)?; + + Ok(pool_state) + } +} + +/// Converts a protocol system name to the name of the adapter file. For example, `balancer_v2` +/// would be converted to `BalancerV2SwapAdapter.evm.runtime`. +/// +/// TODO: document this requirement in a README somewhere under instructions to add support for a +/// new protocol system. +fn to_adapter_file_name(protocol_system: &str) -> String { + protocol_system + .split('_') + .map(|word| { + let mut chars = word.chars(); + match chars.next() { + Some(first) => first.to_uppercase().collect::() + chars.as_str(), + None => String::new(), + } + }) + .collect::() + + "SwapAdapter.evm.runtime" +} + +#[cfg(test)] +mod tests { + use super::*; + + use chrono::DateTime; + use std::{collections::HashSet, str::FromStr}; + + use tycho_core::{ + dto::{Chain, ChangeType, ProtocolComponent, ResponseProtocolState}, + Bytes, + }; + + fn vm_component() -> ProtocolComponent { + let creation_time = DateTime::from_timestamp(1622526000, 0) + .unwrap() + .naive_utc(); //Sample timestamp + + let mut static_attributes: HashMap = HashMap::new(); + static_attributes.insert("manual_updates".to_string(), Bytes::from_str("0x01").unwrap()); + + let dai_addr = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(); + let bal_addr = Bytes::from_str("0xba100000625a3754423978a60c9317c58a424e3d").unwrap(); + let tokens = vec![dai_addr, bal_addr]; + + ProtocolComponent { + id: "0x4626d81b3a1711beb79f4cecff2413886d461677000200000000000000000011".to_string(), + protocol_system: "balancer_v2".to_string(), + protocol_type_name: "balancer_v2_pool".to_string(), + chain: Chain::Ethereum, + tokens, + contract_ids: vec![ + Bytes::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap() + ], + static_attributes, + change: ChangeType::Creation, + creation_tx: Bytes::from_str("0x0000").unwrap(), + created_at: creation_time, + } + } + + #[test] + fn test_to_adapter_file_name() { + assert_eq!(to_adapter_file_name("balancer_v2"), "BalancerV2SwapAdapter.evm.runtime"); + assert_eq!(to_adapter_file_name("uniswap_v3"), "UniswapV3SwapAdapter.evm.runtime"); + } + + #[tokio::test] + async fn test_try_from_with_block() { + let attributes: HashMap = vec![ + ( + "balance_owner".to_string(), + Bytes::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap(), + ), + ("reserve1".to_string(), Bytes::from(200_u64.to_le_bytes().to_vec())), + ] + .into_iter() + .collect(); + let snapshot = ComponentWithState { + state: ResponseProtocolState { + component_id: "0x4626d81b3a1711beb79f4cecff2413886d461677000200000000000000000011" + .to_owned(), + attributes, + balances: HashMap::new(), + }, + component: vm_component(), + }; + + let block = Header { + number: 1, + hash: Bytes::from(vec![0; 32]), + parent_hash: Bytes::from(vec![0; 32]), + revert: false, + }; + + let result = VMPoolState::try_from_with_block(snapshot, block).await; + + assert!(result.is_ok()); + let res = result.unwrap(); + assert_eq!( + res.balance_owner, + Some(H160::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap()) + ); + let mut exp_involved_contracts = HashSet::new(); + exp_involved_contracts + .insert(H160::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap()); + assert_eq!(res.involved_contracts, exp_involved_contracts); + assert!(res.manual_updates); + } +} diff --git a/tutorial/Cargo.toml b/tutorial/Cargo.toml deleted file mode 100644 index e69de29b..00000000 diff --git a/tycho_simulation_py/python/tycho_simulation_py/evm/pool_state.py b/tycho_simulation_py/python/tycho_simulation_py/evm/pool_state.py index 4b7ac14a..59108933 100644 --- a/tycho_simulation_py/python/tycho_simulation_py/evm/pool_state.py +++ b/tycho_simulation_py/python/tycho_simulation_py/evm/pool_state.py @@ -34,21 +34,21 @@ class ThirdPartyPool: def __init__( - self, - id_: str, - tokens: tuple[EthereumToken, ...], - balances: dict[Address, Decimal], - block: EVMBlock, - adapter_contract_path: str, - marginal_prices: dict[tuple[EthereumToken, EthereumToken], Decimal] = None, - stateless_contracts: dict[str, bytes] = None, - capabilities: set[Capability] = None, - balance_owner: Optional[str] = None, - block_lasting_overwrites: defaultdict[Address, dict[int, int]] = None, - manual_updates: bool = False, - trace: bool = False, - involved_contracts=None, - token_storage_slots=None, + self, + id_: str, + tokens: tuple[EthereumToken, ...], + balances: dict[Address, Decimal], + block: EVMBlock, + adapter_contract_path: str, + marginal_prices: dict[tuple[EthereumToken, EthereumToken], Decimal] = None, + stateless_contracts: dict[str, bytes] = None, + capabilities: set[Capability] = None, + balance_owner: Optional[str] = None, + block_lasting_overwrites: defaultdict[Address, dict[int, int]] = None, + manual_updates: bool = False, + trace: bool = False, + involved_contracts=None, + token_storage_slots=None, ): self.id_ = id_ """The pools identifier.""" @@ -81,7 +81,7 @@ def __init__( contract during simulations.""" self.block_lasting_overwrites: defaultdict[Address, dict[int, int]] = ( - block_lasting_overwrites or defaultdict(dict) + block_lasting_overwrites or defaultdict(dict) ) """Storage overwrites that will be applied to all simulations. They will be cleared when ``clear_all_cache`` is called, i.e. usually at each block. Hence the name.""" @@ -98,7 +98,7 @@ def __init__( """A set of all contract addresses involved in the simulation of this pool.""" self.token_storage_slots: dict[Address, tuple[int, int]] = ( - token_storage_slots or {} + token_storage_slots or {} ) """Allows the specification of custom storage slots for token allowances and balances. This is particularly useful for token contracts involved in protocol @@ -179,10 +179,10 @@ def _set_marginal_prices(self): block=self.block, overwrites=self.block_lasting_overwrites, )[0] - if Capability.ScaledPrice in self.capabilities: + if Capability.ScaledPrices in self.capabilities: self.marginal_prices[(t0, t1)] = frac_to_decimal(frac) else: - scaled = frac * Fraction(10 ** t0.decimals, 10 ** t1.decimals) + scaled = frac * Fraction(10**t0.decimals, 10**t1.decimals) self.marginal_prices[(t0, t1)] = frac_to_decimal(scaled) def _ensure_capability(self, capability: Capability): @@ -207,8 +207,8 @@ def _set_capabilities(self): def _init_token_storage_slots(self): for t in self.tokens: if ( - t.address in self.involved_contracts - and t.address not in self.token_storage_slots + t.address in self.involved_contracts + and t.address not in self.token_storage_slots ): self.token_storage_slots[t.address] = slots = token.brute_force_slots( t, self.block, self._engine @@ -216,10 +216,10 @@ def _init_token_storage_slots(self): log.debug(f"Using custom storage slots for {t.address}: {slots}") def get_amount_out( - self: TPoolState, - sell_token: EthereumToken, - sell_amount: Decimal, - buy_token: EthereumToken, + self: TPoolState, + sell_token: EthereumToken, + sell_amount: Decimal, + buy_token: EthereumToken, ) -> tuple[Decimal, int, TPoolState]: # if the pool has a hard limit and the sell amount exceeds that, simulate and # raise a partial trade @@ -236,10 +236,10 @@ def get_amount_out( return self._get_amount_out(sell_token, sell_amount, buy_token) def _get_amount_out( - self: TPoolState, - sell_token: EthereumToken, - sell_amount: Decimal, - buy_token: EthereumToken, + self: TPoolState, + sell_token: EthereumToken, + sell_amount: Decimal, + buy_token: EthereumToken, ) -> tuple[Decimal, int, TPoolState]: overwrites = self._get_overwrites(sell_token, buy_token) trade, state_changes = self._adapter_contract.swap( @@ -268,7 +268,7 @@ def _get_amount_out( return buy_amount, trade.gas_used, new_state def _get_overwrites( - self, sell_token: EthereumToken, buy_token: EthereumToken, **kwargs + self, sell_token: EthereumToken, buy_token: EthereumToken, **kwargs ) -> dict[Address, dict[int, int]]: """Get an overwrites dictionary to use in a simulation. @@ -279,7 +279,7 @@ def _get_overwrites( return _merge(self.block_lasting_overwrites.copy(), token_overwrites) def _get_token_overwrites( - self, sell_token: EthereumToken, buy_token: EthereumToken, max_amount=None + self, sell_token: EthereumToken, buy_token: EthereumToken, max_amount=None ) -> dict[Address, dict[int, int]]: """Creates overwrites for a token. @@ -300,7 +300,7 @@ def _get_token_overwrites( ) overwrites = ERC20OverwriteFactory( sell_token, - token_slots=self.token_storage_slots.get(sell_token.address, (0, 1)) + token_slots=self.token_storage_slots.get(sell_token.address, (0, 1)), ) overwrites.set_balance(max_amount, EXTERNAL_ACCOUNT) overwrites.set_allowance( @@ -352,7 +352,7 @@ def _duplicate(self: "ThirdPartyPool") -> "ThirdPartyPool": ) def get_sell_amount_limit( - self, sell_token: EthereumToken, buy_token: EthereumToken + self, sell_token: EthereumToken, buy_token: EthereumToken ) -> Decimal: """ Retrieves the sell amount of the given token. diff --git a/tycho_simulation_py/python/tycho_simulation_py/models.py b/tycho_simulation_py/python/tycho_simulation_py/models.py index 03374667..09a619c8 100644 --- a/tycho_simulation_py/python/tycho_simulation_py/models.py +++ b/tycho_simulation_py/python/tycho_simulation_py/models.py @@ -113,6 +113,6 @@ class Capability(IntEnum): FeeOnTransfer = auto() ConstantPrice = auto() TokenBalanceIndependent = auto() - ScaledPrice = auto() + ScaledPrices = auto() HardLimits = auto() MarginalPrice = auto()