diff --git a/Cargo.lock b/Cargo.lock
index 583b0ee602..b7e5f9a27e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3697,6 +3697,15 @@ dependencies = [
  "hashbrown 0.12.3",
 ]
 
+[[package]]
+name = "lru"
+version = "0.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1efa59af2ddfad1854ae27d75009d538d0998b4b2fd47083e743ac1a10e46c60"
+dependencies = [
+ "hashbrown 0.14.0",
+]
+
 [[package]]
 name = "mach2"
 version = "0.4.1"
@@ -4131,7 +4140,7 @@ checksum = "5f4e3bc495f6e95bc15a6c0c55ac00421504a5a43d09e3cc455d1fea7015581d"
 dependencies = [
  "bitvec",
  "either",
- "lru",
+ "lru 0.7.8",
  "num-bigint",
  "num-integer",
  "num-modular",
@@ -4623,6 +4632,7 @@ dependencies = [
  "indexmap 1.9.3",
  "itertools 0.10.5",
  "libmdbx",
+ "lru 0.12.0",
  "metrics",
  "mockall",
  "papyrus_base_layer",
diff --git a/Cargo.toml b/Cargo.toml
index 73e46d61c9..b87bcd49b3 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -63,6 +63,7 @@ jsonrpsee = "0.18.1"
 jsonschema = "0.17.0"
 lazy_static = "1.4.0"
 libmdbx = "0.3.5"
+lru = "0.12.0"
 memmap2 = "0.8.0"
 metrics = "0.21.0"
 metrics-exporter-prometheus = "0.12.1"
diff --git a/config/default_config.json b/config/default_config.json
index ff2e06a861..0b17177df9 100644
--- a/config/default_config.json
+++ b/config/default_config.json
@@ -9,6 +9,11 @@
     "privacy": "Public",
     "value": "0xc662c410C0ECf747543f5bA90660f6ABeBD9C8c4"
   },
+  "central.class_cache_size": {
+    "description": "Size of class cache, must be a positive integer.",
+    "privacy": "Public",
+    "value": 30
+  },
   "central.concurrent_requests": {
     "description": "Maximum number of concurrent requests to Starknet feeder-gateway for getting a type of data (for example, blocks).",
     "privacy": "Public",
@@ -189,4 +194,4 @@
     "privacy": "Public",
     "value": 1000
   }
-}
\ No newline at end of file
+}
diff --git a/crates/papyrus_node/src/config/snapshots/papyrus_node__config__config_test__dump_default_config.snap b/crates/papyrus_node/src/config/snapshots/papyrus_node__config__config_test__dump_default_config.snap
index 45e92dec41..e8dee38236 100644
--- a/crates/papyrus_node/src/config/snapshots/papyrus_node__config__config_test__dump_default_config.snap
+++ b/crates/papyrus_node/src/config/snapshots/papyrus_node__config__config_test__dump_default_config.snap
@@ -13,6 +13,13 @@ expression: dumped_default_config
     "value": "0xc662c410C0ECf747543f5bA90660f6ABeBD9C8c4",
     "privacy": "Public"
   },
+  "central.class_cache_size": {
+    "description": "Size of class cache, must be a positive integer.",
+    "value": {
+      "$serde_json::private::Number": "30"
+    },
+    "privacy": "Public"
+  },
   "central.concurrent_requests": {
     "description": "Maximum number of concurrent requests to Starknet feeder-gateway for getting a type of data (for example, blocks).",
     "value": {
diff --git a/crates/papyrus_sync/Cargo.toml b/crates/papyrus_sync/Cargo.toml
index d060238487..ff899276d2 100644
--- a/crates/papyrus_sync/Cargo.toml
+++ b/crates/papyrus_sync/Cargo.toml
@@ -17,6 +17,7 @@ hex.workspace = true
 indexmap = { workspace = true, features = ["serde"] }
 itertools.workspace = true
 libmdbx = { workspace = true, features = ["lifetimed-bytes"] }
+lru.workspace = true
 metrics.workspace = true
 papyrus_storage = { path = "../papyrus_storage", version = "0.0.5" }
 papyrus_base_layer = { path = "../papyrus_base_layer" }
diff --git a/crates/papyrus_sync/src/sources/central.rs b/crates/papyrus_sync/src/sources/central.rs
index 842a68f155..990a26418b 100644
--- a/crates/papyrus_sync/src/sources/central.rs
+++ b/crates/papyrus_sync/src/sources/central.rs
@@ -4,7 +4,8 @@ mod central_test;
 mod state_update_stream;
 
 use std::collections::{BTreeMap, HashMap};
-use std::sync::Arc;
+use std::num::NonZeroUsize;
+use std::sync::{Arc, Mutex};
 
 use async_stream::stream;
 use async_trait::async_trait;
@@ -13,6 +14,7 @@ use futures::stream::BoxStream;
 use futures_util::StreamExt;
 use indexmap::IndexMap;
 use itertools::chain;
+use lru::LruCache;
 #[cfg(test)]
 use mockall::automock;
 use papyrus_common::BlockHashAndNumber;
@@ -49,6 +51,8 @@ pub struct CentralSourceConfig {
     pub max_state_updates_to_download: usize,
     pub max_state_updates_to_store_in_memory: usize,
     pub max_classes_to_download: usize,
+    // TODO(dan): validate that class_cache_size is a positive integer.
+    pub class_cache_size: usize,
     pub retry_config: RetryConfig,
 }
 
@@ -61,6 +65,7 @@ impl Default for CentralSourceConfig {
             max_state_updates_to_download: 20,
             max_state_updates_to_store_in_memory: 20,
             max_classes_to_download: 20,
+            class_cache_size: 30,
             retry_config: RetryConfig {
                 retry_base_millis: 30,
                 retry_max_delay_millis: 30000,
@@ -110,6 +115,12 @@ impl SerializeConfig for CentralSourceConfig {
                 "Maximum number of classes to download at a given time.",
                 ParamPrivacyInput::Public,
             ),
+            ser_param(
+                "class_cache_size",
+                &self.class_cache_size,
+                "Size of class cache, must be a positive integer.",
+                ParamPrivacyInput::Public,
+            ),
         ]);
         chain!(self_params_dump, append_sub_config_name(self.retry_config.dump(), "retry_config"))
             .collect()
@@ -121,10 +132,11 @@ pub struct GenericCentralSource {
     pub starknet_client: Arc,
     pub storage_reader: StorageReader,
     pub state_update_stream_config: StateUpdateStreamConfig,
+    pub(crate) class_cache: Arc<Mutex<LruCache<ClassHash, ApiContractClass>>>,
 }
 
 #[derive(Clone)]
-enum ApiContractClass {
+pub(crate) enum ApiContractClass {
     DeprecatedContractClass(starknet_api::deprecated_contract_class::ContractClass),
     ContractClass(starknet_api::state::ContractClass),
 }
@@ -251,6 +263,7 @@ impl CentralSourceTrait
             self.starknet_client.clone(),
             self.storage_reader.clone(),
             self.state_update_stream_config.clone(),
+            self.class_cache.clone(),
         )
         .boxed()
     }
@@ -395,6 +408,10 @@ impl CentralSource {
                 max_state_updates_to_store_in_memory: config.max_state_updates_to_store_in_memory,
                 max_classes_to_download: config.max_classes_to_download,
             },
+            class_cache: Arc::from(Mutex::new(LruCache::new(
+                NonZeroUsize::new(config.class_cache_size)
+                    .expect("class_cache_size should be a positive integer."),
+            ))),
         })
     }
 }
diff --git a/crates/papyrus_sync/src/sources/central/state_update_stream.rs b/crates/papyrus_sync/src/sources/central/state_update_stream.rs
index 14c4227115..747ac10cab 100644
--- a/crates/papyrus_sync/src/sources/central/state_update_stream.rs
+++ b/crates/papyrus_sync/src/sources/central/state_update_stream.rs
@@ -1,11 +1,12 @@
 use std::collections::VecDeque;
 use std::pin::Pin;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 use std::task::Poll;
 
 use futures_util::stream::FuturesOrdered;
 use futures_util::{Future, Stream, StreamExt};
 use indexmap::IndexMap;
+use lru::LruCache;
 use papyrus_storage::state::StateStorageReader;
 use papyrus_storage::StorageReader;
 use starknet_api::block::BlockNumber;
@@ -39,6 +40,7 @@ pub(crate) struct StateUpdateStream
     download_class_tasks: TasksQueue<CentralResult<Option<ApiContractClass>>>,
     downloaded_classes: VecDeque<ApiContractClass>,
+    class_cache: Arc<Mutex<LruCache<ClassHash, ApiContractClass>>>,
     config: StateUpdateStreamConfig,
 }
 
@@ -89,6 +91,7 @@ impl StateUpdateStream<
         starknet_client: Arc,
         storage_reader: StorageReader,
         config: StateUpdateStreamConfig,
+        class_cache: Arc<Mutex<LruCache<ClassHash, ApiContractClass>>>,
     ) -> Self {
         StateUpdateStream {
             initial_block_number,
@@ -107,6 +110,7 @@ impl StateUpdateStream<
                 config.max_state_updates_to_store_in_memory * 5,
             ),
             config,
+            class_cache,
         }
     }
 
@@ -151,7 +155,9 @@ impl StateUpdateStream<
         };
         let starknet_client = self.starknet_client.clone();
         let storage_reader = self.storage_reader.clone();
+        let cache = self.class_cache.clone();
         self.download_class_tasks.push_back(Box::pin(download_class_if_necessary(
+            cache,
            class_hash,
            starknet_client,
            storage_reader,
@@ -330,10 +336,18 @@ fn client_to_central_state_update(
 // If not found in the storage, the class is downloaded.
 #[instrument(skip(starknet_client, storage_reader), level = "debug", err)]
 async fn download_class_if_necessary(
+    cache: Arc<Mutex<LruCache<ClassHash, ApiContractClass>>>,
     class_hash: ClassHash,
     starknet_client: Arc,
     storage_reader: StorageReader,
 ) -> CentralResult<Option<ApiContractClass>> {
+    {
+        let mut cache = cache.lock().expect("Failed to lock class cache.");
+        if let Some(class) = cache.get(&class_hash) {
+            return Ok(Some(class.clone()));
+        }
+    }
+
     let txn = storage_reader.begin_ro_txn()?;
     let state_reader = txn.get_state_reader()?;
     let block_number = txn.get_state_marker()?;
@@ -342,6 +356,10 @@ async fn download_class_if_necessary(
     // Check declared classes.
     if let Ok(Some(class)) = state_reader.get_class_definition_at(state_number, &class_hash) {
         trace!("Class {:?} retrieved from storage.", class_hash);
+        {
+            let mut cache = cache.lock().expect("Failed to lock class cache.");
+            cache.put(class_hash, ApiContractClass::ContractClass(class.clone()));
+        }
         return Ok(Some(ApiContractClass::ContractClass(class)));
     };
 
@@ -350,6 +368,10 @@
         state_reader.get_deprecated_class_definition_at(state_number, &class_hash)
     {
         trace!("Deprecated class {:?} retrieved from storage.", class_hash);
+        {
+            let mut cache = cache.lock().expect("Failed to lock class cache.");
+            cache.put(class_hash, ApiContractClass::DeprecatedContractClass(class.clone()));
+        }
         return Ok(Some(ApiContractClass::DeprecatedContractClass(class)));
     }
 
@@ -358,6 +380,12 @@
     let client_class = starknet_client.class_by_hash(class_hash).await.map_err(Arc::new)?;
     match client_class {
         None => Ok(None),
-        Some(class) => Ok(Some(class.into())),
+        Some(class) => {
+            {
+                let mut cache = cache.lock().expect("Failed to lock class cache.");
+                cache.put(class_hash, class.clone().into());
+            }
+            Ok(Some(class.into()))
+        }
     }
 }
diff --git a/crates/papyrus_sync/src/sources/central_test.rs b/crates/papyrus_sync/src/sources/central_test.rs
index ce3af9b1f3..8083c41ec4 100644
--- a/crates/papyrus_sync/src/sources/central_test.rs
+++ b/crates/papyrus_sync/src/sources/central_test.rs
@@ -1,9 +1,11 @@
-use std::sync::Arc;
+use std::num::NonZeroUsize;
+use std::sync::{Arc, Mutex};
 
 use assert_matches::assert_matches;
 use cairo_lang_starknet::casm_contract_class::CasmContractClass;
 use futures_util::pin_mut;
 use indexmap::{indexmap, IndexMap};
+use lru::LruCache;
 use mockall::predicate;
 use papyrus_storage::state::StateStorageWriter;
 use papyrus_storage::test_utils::get_test_storage;
@@ -38,6 +40,7 @@ use starknet_client::ClientError;
 use tokio_stream::StreamExt;
 
 use super::state_update_stream::StateUpdateStreamConfig;
+use super::ApiContractClass;
 use crate::sources::central::{CentralError, CentralSourceTrait, GenericCentralSource};
 
 const TEST_CONCURRENT_REQUESTS: usize = 300;
@@ -58,6 +61,7 @@ async fn last_block_number() {
         concurrent_requests: TEST_CONCURRENT_REQUESTS,
         storage_reader: reader,
         state_update_stream_config: state_update_stream_config_for_test(),
+        class_cache: get_test_class_cache(),
     };
     let last_block_number =
         central_source.get_latest_block().await.unwrap().unwrap().block_number;
@@ -83,6 +87,7 @@ async fn stream_block_headers() {
         starknet_client: Arc::new(mock),
         storage_reader: reader,
         state_update_stream_config: state_update_stream_config_for_test(),
+        class_cache: get_test_class_cache(),
     };
 
     let mut expected_block_num = BlockNumber(START_BLOCK_NUMBER);
@@ -120,6 +125,7 @@ async fn stream_block_headers_some_are_missing() {
         starknet_client: Arc::new(mock),
         storage_reader: reader,
         state_update_stream_config: state_update_stream_config_for_test(),
+        class_cache: get_test_class_cache(),
     };
 
     let mut expected_block_num = BlockNumber(START_BLOCK_NUMBER);
@@ -172,6 +178,7 @@ async fn stream_block_headers_error() {
         starknet_client: Arc::new(mock),
         storage_reader: reader,
         state_update_stream_config: state_update_stream_config_for_test(),
+        class_cache: get_test_class_cache(),
     };
 
     let mut expected_block_num = BlockNumber(START_BLOCK_NUMBER);
@@ -308,6 +315,7 @@ async fn stream_state_updates() {
         starknet_client: Arc::new(mock),
         storage_reader: reader,
         state_update_stream_config: state_update_stream_config_for_test(),
+        class_cache: get_test_class_cache(),
     };
 
     let initial_block_num = BlockNumber(START_BLOCK_NUMBER);
@@ -418,6 +426,7 @@ async fn stream_compiled_classes() {
         starknet_client: Arc::new(mock),
         storage_reader: reader,
         state_update_stream_config: state_update_stream_config_for_test(),
+        class_cache: get_test_class_cache(),
     };
 
     let stream = central_source.stream_compiled_classes(BlockNumber(0), BlockNumber(2));
@@ -442,3 +451,7 @@ fn state_update_stream_config_for_test() -> StateUpdateStreamConfig {
         max_classes_to_download: 10,
     }
 }
+
+fn get_test_class_cache() -> Arc<Mutex<LruCache<ClassHash, ApiContractClass>>> {
+    Arc::from(Mutex::new(LruCache::new(NonZeroUsize::new(2).unwrap())))
+}
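
Note on the caching pattern: the diff threads an Arc<Mutex<LruCache<ClassHash, ApiContractClass>>> into download_class_if_necessary so the cache is consulted before storage and before the feeder gateway, and populated on every hit or download. The sketch below illustrates that get-then-put discipline with the same lru 0.12 API used here (LruCache::new(NonZeroUsize), get, put). It is only a minimal, self-contained illustration: the get_or_fetch helper and its u64/String key and value types are hypothetical stand-ins, not papyrus code.

// Illustrative sketch of the Arc<Mutex<LruCache>> get-then-put pattern (assumed names).
use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex};

use lru::LruCache;

fn get_or_fetch(
    cache: &Arc<Mutex<LruCache<u64, String>>>,
    key: u64,
    fetch: impl FnOnce(u64) -> String,
) -> String {
    // Check the cache first, holding the lock only for the lookup.
    {
        let mut cache = cache.lock().expect("Failed to lock cache.");
        if let Some(value) = cache.get(&key) {
            return value.clone();
        }
    }
    // Cache miss: fetch the value (in the diff: storage, then the feeder gateway),
    // then insert it so later requests are served from memory.
    let value = fetch(key);
    {
        let mut cache = cache.lock().expect("Failed to lock cache.");
        cache.put(key, value.clone());
    }
    value
}

fn main() {
    // Capacity must be a positive integer, mirroring the class_cache_size config check.
    let cache = Arc::from(Mutex::new(LruCache::new(
        NonZeroUsize::new(30).expect("cache size should be a positive integer."),
    )));
    // First call misses and "downloads"; second call is served from the cache.
    assert_eq!(get_or_fetch(&cache, 7, |k| format!("class-{k}")), "class-7");
    assert_eq!(get_or_fetch(&cache, 7, |_| unreachable!()), "class-7");
}

Releasing the lock between the lookup and the fetch means two tasks can race to download the same class; the second put simply overwrites the first, which is harmless here and avoids holding a std Mutex across an await point.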