diff --git a/Cargo.lock b/Cargo.lock index 20e0a5bc..7fbbe3dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2158,6 +2158,12 @@ version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "dunce" version = "1.0.5" @@ -3065,6 +3071,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "fragile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" + [[package]] name = "fs4" version = "0.9.1" @@ -4255,6 +4267,32 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "mockall" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "native-tls" version = "0.2.12" @@ -4829,6 +4867,32 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" +[[package]] +name = "predicates" +version = "3.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" + +[[package]] +name = "predicates-tree" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "pretty_assertions" version = "1.4.1" @@ -6578,6 +6642,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "termtree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" + [[package]] name = "thiserror" version = "1.0.69" @@ -7138,6 +7208,7 @@ dependencies = [ "itertools 0.10.5", "lazy_static", "mini-moka", + "mockall", "num-bigint", "num-traits", "ratatui", @@ -7602,7 +7673,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 6ecc6ec3..a526614c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,6 +66,9 @@ tracing-subscriber = { version = "0.3.17", default-features = false, features = ] } tempfile = "3.13.0" +# testing +mockall = "0.13" + # price_printer example clap = { version = "4.5.3", features = ["derive"] } anyhow = "1.0.79" diff --git a/src/evm/decoder.rs b/src/evm/decoder.rs index d2d427ad..15fd0e13 100644 --- a/src/evm/decoder.rs +++ b/src/evm/decoder.rs @@ -1,5 +1,5 @@ use std::{ - collections::{hash_map::Entry, HashMap}, + collections::{hash_map::Entry, HashMap, HashSet}, future::Future, pin::Pin, str::FromStr, @@ -11,7 +11,7 @@ use thiserror::Error; use tokio::sync::RwLock; use tracing::{debug, error, info, warn}; use tycho_client::feed::{synchronizer::ComponentWithState, FeedMessage, Header}; -use tycho_core::Bytes; +use tycho_core::{dto::ProtocolStateDelta, Bytes}; use crate::{ evm::{ @@ -36,6 +36,8 @@ pub enum StreamDecodeError { struct DecoderState { tokens: HashMap, states: HashMap>, + // maps contract address to the pools they affect + contracts_map: HashMap>, } type DecodeFut = @@ -149,6 +151,7 @@ impl TychoStreamDecoder { let mut updated_states = HashMap::new(); let mut new_pairs = HashMap::new(); let mut removed_pairs = HashMap::new(); + let mut contracts_map = HashMap::new(); let block = msg .state_msgs @@ -190,6 +193,7 @@ impl TychoStreamDecoder { } } + // Remove untracked components let state_guard = self.state.read().await; removed_pairs.extend( protocol_msg @@ -282,6 +286,20 @@ impl TychoStreamDecoder { component_tokens, ), ); + // collect contracts:ids mapping for states that should update on contract changes + for component in new_pairs.values() { + if component + .static_attributes + .contains_key("manual_updates") + { + for contract in &component.contract_ids { + contracts_map + .entry(contract.clone()) + .or_insert_with(Vec::new) + .push(id.clone()); + } + } + } // Construct state from snapshot if let Some(state_decode_f) = self.registry.get(protocol.as_str()) { @@ -333,6 +351,72 @@ impl TychoStreamDecoder { .await; info!("Engine updated"); + // update states related to contracts with account deltas + let mut pools_to_update = HashSet::new(); + // get pools related to the updated accounts + for (account, _update) in deltas.account_updates { + // get new pools related to the account updated + pools_to_update.extend( + contracts_map + .get(&account) + .cloned() + .unwrap_or_default(), + ); + // get existing pools related to the account updated + pools_to_update.extend( + state_guard + .contracts_map + .get(&account) + .cloned() + .unwrap_or_default(), + ); + } + // update the pools + for pool in pools_to_update { + match updated_states.entry(pool.clone()) { + Entry::Occupied(mut entry) => { + // if state exists in updated_states, update it + let state: &mut Box = entry.get_mut(); + state + .delta_transition( + ProtocolStateDelta::default(), + &state_guard.tokens, + ) + .map_err(|e| { + error!(pool = pool, error = ?e, "DeltaTransitionError"); + StreamDecodeError::Fatal(format!("TransitionFailure: {e:?}")) + })?; + } + Entry::Vacant(_) => { + match state_guard.states.get(&pool) { + // if state does not exist in updated_states, update the stored + // state + Some(stored_state) => { + let mut state = stored_state.clone(); + state + .delta_transition( + ProtocolStateDelta::default(), + &state_guard.tokens, + ) + .map_err(|e| { + error!(pool = pool, error = ?e, "DeltaTransitionError"); + StreamDecodeError::Fatal(format!( + "TransitionFailure: {e:?}" + )) + })?; + updated_states.insert(pool.clone(), state); + } + None => debug!( + pool = pool, + reason = "MissingState", + "DeltaTransitionError" + ), + } + } + } + } + + // update states with protocol state deltas (attribute changes etc.) for (id, update) in deltas.state_updates { match updated_states.entry(id.clone()) { Entry::Occupied(mut entry) => { @@ -372,12 +456,18 @@ impl TychoStreamDecoder { } }; } - // Persist the newly added/updated states let mut state_guard = self.state.write().await; state_guard .states .extend(updated_states.clone().into_iter()); + for (key, values) in contracts_map { + state_guard + .contracts_map + .entry(key) + .or_insert_with(Vec::new) + .extend(values); + } // Send the tick with all updated states Ok(BlockUpdate::new(block.number, updated_states, new_pairs) @@ -389,17 +479,14 @@ impl TychoStreamDecoder { mod tests { use std::{fs, path::Path}; + use mockall::predicate::*; use num_bigint::ToBigUint; use rstest::*; - use tycho_client::feed::FeedMessage; - use tycho_core::Bytes; + use super::*; use crate::{ - evm::{ - decoder::{StreamDecodeError, TychoStreamDecoder}, - protocol::uniswap_v2::state::UniswapV2State, - }, - models::Token, + evm::protocol::uniswap_v2::state::UniswapV2State, models::Token, + protocol::state::MockProtocolSim, }; async fn setup_decoder(set_tokens: bool) -> TychoStreamDecoder { @@ -525,4 +612,59 @@ mod tests { } } } + + #[tokio::test] + async fn test_decode_updates_state_on_contract_change() { + let decoder = setup_decoder(true).await; + + // Create the mock instances + let mut mock_state = MockProtocolSim::new(); + + mock_state + .expect_clone_box() + .times(1) + .returning(|| { + let mut cloned_mock_state = MockProtocolSim::new(); + // Expect `delta_transition` to be called once with any parameters + cloned_mock_state + .expect_delta_transition() + .times(1) + .returning(|_, _| Ok(())); + cloned_mock_state + .expect_clone_box() + .times(1) + .returning(|| Box::new(MockProtocolSim::new())); + Box::new(cloned_mock_state) + }); + + // Insert mock state into `updated_states` + let pool_id = + "0x93d199263632a4ef4bb438f1feb99e57b4b5f0bd0000000000000000000005c2".to_string(); + decoder + .state + .write() + .await + .states + .insert(pool_id.clone(), Box::new(mock_state) as Box); + decoder + .state + .write() + .await + .contracts_map + .insert( + Bytes::from("0xba12222222228d8ba445958a75a0704d566bf2c8").lpad(20, 0), + vec![pool_id.clone()], + ); + + // Load a test message containing a contract update + let msg = load_test_msg("balancer_v2_delta"); + + // Decode the message + let _ = decoder + .decode(msg) + .await + .expect("decode failure"); + + // The mock framework will assert that `delta_transition` was called exactly once + } } diff --git a/src/evm/protocol/vm/tycho_decoder.rs b/src/evm/protocol/vm/tycho_decoder.rs index b785205a..9c7e5c57 100644 --- a/src/evm/protocol/vm/tycho_decoder.rs +++ b/src/evm/protocol/vm/tycho_decoder.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, str::FromStr, time::{SystemTime, UNIX_EPOCH}, }; @@ -115,7 +115,7 @@ impl TryFromWithBlock for EVMPoolState { .contract_ids .iter() .map(|bytes: &Bytes| Address::from_slice(bytes.as_ref())) - .collect(); + .collect::>(); let protocol_name = snapshot .component @@ -228,7 +228,7 @@ mod tests { fn load_balancer_account_data() -> Vec { let project_root = env!("CARGO_MANIFEST_DIR"); let asset_path = - Path::new(project_root).join("tests/assets/decoder/balancer_snapshot.json"); + Path::new(project_root).join("tests/assets/decoder/balancer_v2_snapshot.json"); let json_data = fs::read_to_string(asset_path).expect("Failed to read test asset"); let data: Value = serde_json::from_str(&json_data).expect("Failed to parse JSON"); @@ -312,14 +312,16 @@ mod tests { .await .unwrap(); + let res_pool = res; + assert_eq!( - res.get_balance_owner(), + res_pool.get_balance_owner(), Some(Address::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap()) ); let mut exp_involved_contracts = HashSet::new(); exp_involved_contracts .insert(Address::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap()); - assert_eq!(res.get_involved_contracts(), exp_involved_contracts); - assert!(res.get_manual_updates()); + assert_eq!(res_pool.get_involved_contracts(), exp_involved_contracts); + assert!(res_pool.get_manual_updates()); } } diff --git a/src/protocol/state.rs b/src/protocol/state.rs index c587b925..a6a87204 100644 --- a/src/protocol/state.rs +++ b/src/protocol/state.rs @@ -46,6 +46,8 @@ //! ``` use std::{any::Any, collections::HashMap}; +#[cfg(test)] +use mockall::mock; use num_bigint::BigUint; use tycho_core::{dto::ProtocolStateDelta, Bytes}; @@ -142,3 +144,69 @@ impl Clone for Box { self.clone_box() } } + +#[cfg(test)] +mock! { + #[derive(Debug)] + pub ProtocolSim { + pub fn fee(&self) -> f64; + pub fn spot_price(&self, base: &Token, quote: &Token) -> Result; + pub fn get_amount_out( + &self, + amount_in: BigUint, + token_in: &Token, + token_out: &Token, + ) -> Result; + pub fn delta_transition( + &mut self, + delta: ProtocolStateDelta, + tokens: &HashMap, + ) -> Result<(), TransitionError>; + pub fn clone_box(&self) -> Box; + pub fn eq(&self, other: &dyn ProtocolSim) -> bool; + } +} + +#[cfg(test)] +impl ProtocolSim for MockProtocolSim { + fn fee(&self) -> f64 { + self.fee() + } + + fn spot_price(&self, base: &Token, quote: &Token) -> Result { + self.spot_price(base, quote) + } + + fn get_amount_out( + &self, + amount_in: BigUint, + token_in: &Token, + token_out: &Token, + ) -> Result { + self.get_amount_out(amount_in, token_in, token_out) + } + + fn delta_transition( + &mut self, + delta: ProtocolStateDelta, + tokens: &HashMap, + ) -> Result<(), TransitionError> { + self.delta_transition(delta, tokens) + } + + fn clone_box(&self) -> Box { + self.clone_box() + } + + fn as_any(&self) -> &dyn Any { + panic!("MockProtocolSim does not support as_any") + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + panic!("MockProtocolSim does not support as_any_mut") + } + + fn eq(&self, other: &dyn ProtocolSim) -> bool { + self.eq(other) + } +} diff --git a/tests/assets/decoder/balancer_v2_delta.json b/tests/assets/decoder/balancer_v2_delta.json new file mode 100644 index 00000000..cc05503e --- /dev/null +++ b/tests/assets/decoder/balancer_v2_delta.json @@ -0,0 +1,77 @@ +{ + "state_msgs": { + "vm:balancer_v2": { + "header": { + "hash": "0x985c985381d51f7768902baa56da51819c843d300d128a505705833ea68dc210", + "number": 21823189, + "parent_hash": "0x298b11c34ed6d8d13f5cdb9b86528a4dd2e12aedd5121ce6dc698761e1b27f6e", + "revert": false + }, + "snapshots": { + "states": {}, + "vm_storage": {} + }, + "deltas": { + "extractor": "vm:balancer_v2", + "chain": "ethereum", + "block": { + "number": 21823189, + "hash": "0x985c985381d51f7768902baa56da51819c843d300d128a505705833ea68dc210", + "parent_hash": "0x298b11c34ed6d8d13f5cdb9b86528a4dd2e12aedd5121ce6dc698761e1b27f6e", + "chain": "ethereum", + "ts": "2025-02-11T12:09:35" + }, + "finalized_block_height": 21823111, + "revert": false, + "new_tokens": {}, + "account_updates": { + "0xba12222222228d8ba445958a75a0704d566bf2c8": { + "address": "0xba12222222228d8ba445958a75a0704d566bf2c8", + "chain": "ethereum", + "slots": { + "0x43f6aa0cddef7ef5e613fa5a2609fdc84dcdcc7deb972c7b99a7def94b47469a": "0x014cfed5000000000000000000000000000000000000008d518759c491e97bb6", + "0x68937bc243fb8cd085e8097a3892e1d1ea282143f39fc35e8cd0c3f7550b8a2a": "0x014cfed500000000000000000000000000000000000001d3eb473a072ae36163" + }, + "balance": null, + "code": null, + "change": "Update" + } + }, + "state_updates": {}, + "new_protocol_components": {}, + "deleted_protocol_components": {}, + "component_balances": { + "0x93d199263632a4ef4bb438f1feb99e57b4b5f0bd0000000000000000000005c2": { + "0x7f39c581f595b53c5cb19bd0b3f8da6c935e2ca0": { + "token": "0x7f39c581f595b53c5cb19bd0b3f8da6c935e2ca0", + "balance": "0x01d3eb473a072ae36163", + "balance_float": 8.631583065547079e+21, + "modify_tx": "0xecd63d913fd99d950565c9639446ffc9ab354807f0af256e454bc2924b170f60", + "component_id": "0x93d199263632a4ef4bb438f1feb99e57b4b5f0bd0000000000000000000005c2" + }, + "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2": { + "token": "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2", + "balance": "0x8d518759c491e97bb6", + "balance_float": 2.606865677332771e+21, + "modify_tx": "0xecd63d913fd99d950565c9639446ffc9ab354807f0af256e454bc2924b170f60", + "component_id": "0x93d199263632a4ef4bb438f1feb99e57b4b5f0bd0000000000000000000005c2" + } + } + }, + "component_tvl": { + "0x93d199263632a4ef4bb438f1feb99e57b4b5f0bd0000000000000000000005c2": 12902.997863620163 + } + }, + "removed_components": {} + } + }, + "sync_states": { + "vm:balancer_v2": { + "status": "ready", + "hash": "0x985c985381d51f7768902baa56da51819c843d300d128a505705833ea68dc210", + "number": 21823189, + "parent_hash": "0x298b11c34ed6d8d13f5cdb9b86528a4dd2e12aedd5121ce6dc698761e1b27f6e", + "revert": false + } + } +} \ No newline at end of file diff --git a/tests/assets/decoder/balancer_snapshot.json b/tests/assets/decoder/balancer_v2_snapshot.json similarity index 100% rename from tests/assets/decoder/balancer_snapshot.json rename to tests/assets/decoder/balancer_v2_snapshot.json diff --git a/tycho_simulation_py/python/test/test_third_party_pool.py b/tycho_simulation_py/python/test/test_third_party_pool.py index f53f9207..2c920760 100644 --- a/tycho_simulation_py/python/test/test_third_party_pool.py +++ b/tycho_simulation_py/python/test/test_third_party_pool.py @@ -1,6 +1,5 @@ import json from decimal import Decimal -from pathlib import Path from unittest.mock import patch, call import pytest