diff --git a/src/protocol/vm/protosim_contract.rs b/src/protocol/vm/protosim_contract.rs index 19039c0e..c47a2e8b 100644 --- a/src/protocol/vm/protosim_contract.rs +++ b/src/protocol/vm/protosim_contract.rs @@ -53,7 +53,7 @@ pub struct ProtoSimResponse { /// # Errors /// Returns errors of type `ProtosimError` when encoding, decoding, or simulation operations fail. /// These errors provide detailed feedback on potential issues. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct ProtosimContract { abi: Abi, address: Address, diff --git a/src/protocol/vm/state.rs b/src/protocol/vm/state.rs index 7f79c288..bcd85406 100644 --- a/src/protocol/vm/state.rs +++ b/src/protocol/vm/state.rs @@ -1,6 +1,7 @@ // TODO: remove skip for clippy dead_code check #![allow(dead_code)] +use std::any::Any; use tracing::warn; use crate::{ @@ -40,12 +41,19 @@ use revm::{ DatabaseRef, }; use std::collections::{HashMap, HashSet}; +use tycho_core::dto::ProtocolStateDelta; // Necessary for the init_account method to be in scope #[allow(unused_imports)] use crate::evm::engine_db_interface::EngineDatabaseInterface; +use crate::protocol::{ + errors::{TradeSimulationError, TransitionError}, + events::{EVMLogMeta, LogIndex}, + models::GetAmountOutResult, + state::{ProtocolEvent, ProtocolSim}, +}; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct VMPoolState { /// The pool's identifier pub id: String, @@ -59,6 +67,8 @@ pub struct VMPoolState { /// If given, balances will be overwritten here instead of on the pool contract during /// simulations pub balance_owner: Option, + /// Spot prices of the pool by token pair + pub spot_prices: HashMap<(H160, H160), f64>, /// The supported capabilities of this pool pub capabilities: HashSet, /// Storage overwrites that will be applied to all simulations. They will be cleared @@ -102,6 +112,7 @@ impl VMPoolState { block, balances, balance_owner, + spot_prices: HashMap::new(), capabilities, block_lasting_overwrites, involved_contracts, @@ -330,12 +341,8 @@ impl VMPoolState { Ok(()) } - async fn get_spot_prices( - &mut self, - tokens: Vec, - ) -> Result, ProtosimError> { + pub async fn set_spot_prices(&mut self, tokens: Vec) -> Result<(), ProtosimError> { self.ensure_capability(Capability::PriceFunction)?; - let mut spot_prices: HashMap<(ERC20Token, ERC20Token), f64> = HashMap::new(); for [sell_token, buy_token] in tokens .iter() .permutations(2) @@ -378,9 +385,10 @@ impl VMPoolState { 10f64.powi(buy_token.decimals as i32) }; - spot_prices.insert(((*sell_token).clone(), (*buy_token).clone()), price); + self.spot_prices + .insert((sell_token.address, buy_token.address), price); } - Ok(spot_prices) + Ok(()) } /// Retrieves the sell amount limit for a given pair of tokens, where the first token is treated @@ -535,6 +543,62 @@ impl VMPoolState { } } +impl ProtocolSim for VMPoolState { + fn fee(&self) -> f64 { + todo!() + } + + fn spot_price(&self, base: &ERC20Token, quote: &ERC20Token) -> f64 { + *self + .spot_prices + .get(&(base.address, quote.address)) + .expect("Spot price not found") + } + + fn get_amount_out( + &self, + _amount_in: U256, + _token_in: &ERC20Token, + _token_out: &ERC20Token, + ) -> Result { + todo!() + } + + fn delta_transition( + &mut self, + _delta: ProtocolStateDelta, + ) -> Result<(), TransitionError> { + todo!() + } + + fn event_transition( + &mut self, + _protocol_event: Box, + _log: &EVMLogMeta, + ) -> Result<(), TransitionError> { + todo!() + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn eq(&self, other: &dyn ProtocolSim) -> bool { + if let Some(other_state) = other + .as_any() + .downcast_ref::>() + { + self.id == other_state.id + } else { + false + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -721,19 +785,21 @@ mod tests { } #[tokio::test] - async fn test_get_spot_prices() { + async fn test_set_spot_prices() { let mut pool_state = setup_pool_state().await; - let spot_prices = pool_state - .get_spot_prices(pool_state.tokens.clone()) + pool_state + .set_spot_prices(pool_state.tokens.clone()) .await .unwrap(); - let dai_bal_spot_price = spot_prices - .get(&(pool_state.tokens[0].clone(), pool_state.tokens[1].clone())) + let dai_bal_spot_price = pool_state + .spot_prices + .get(&(pool_state.tokens[0].address, pool_state.tokens[1].address)) .unwrap(); - let bal_dai_spot_price = spot_prices - .get(&(pool_state.tokens[1].clone(), pool_state.tokens[0].clone())) + let bal_dai_spot_price = pool_state + .spot_prices + .get(&(pool_state.tokens[1].address, pool_state.tokens[0].address)) .unwrap(); assert_eq!(dai_bal_spot_price, &0.137_778_914_319_047_9); assert_eq!(bal_dai_spot_price, &7.071_503_245_428_246);