From 4c163820a86b2fd3fb357310471ef311d7a83911 Mon Sep 17 00:00:00 2001 From: Gwo Tzu-Hsing Date: Tue, 17 Dec 2024 14:50:14 +0800 Subject: [PATCH] refactor: remove reduant boxing in trigger --- src/compaction/mod.rs | 2 +- src/inmem/mutable.rs | 12 +++++----- src/lib.rs | 12 +++++----- src/stream/mem_projection.rs | 2 +- src/stream/merge.rs | 10 ++++---- src/stream/package.rs | 2 +- src/trigger.rs | 46 +++++++++++++++++++----------------- 7 files changed, 44 insertions(+), 42 deletions(-) diff --git a/src/compaction/mod.rs b/src/compaction/mod.rs index c6e1c0c..95a37de 100644 --- a/src/compaction/mod.rs +++ b/src/compaction/mod.rs @@ -590,7 +590,7 @@ pub(crate) mod tests { where R: Record + Send, { - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let mutable: Mutable = Mutable::new(option, trigger, fs, schema.clone()).await?; diff --git a/src/inmem/mutable.rs b/src/inmem/mutable.rs index 0212bf3..7e37a50 100644 --- a/src/inmem/mutable.rs +++ b/src/inmem/mutable.rs @@ -37,7 +37,7 @@ where { pub(crate) data: SkipMap::Key>, Option>, wal: Option, R>>>, - pub(crate) trigger: Arc + Send + Sync>>, + pub(crate) trigger: Arc>, pub(super) schema: Arc, } @@ -48,7 +48,7 @@ where { pub async fn new( option: &DbOption, - trigger: Arc + Send + Sync>>, + trigger: Arc>, fs: &Arc, schema: Arc, ) -> Result { @@ -119,7 +119,7 @@ where .map_err(|e| DbError::WalWrite(Box::new(e)))?; } - let is_exceeded = self.trigger.item(&value); + let is_exceeded = self.trigger.check_if_exceed(&value); self.data.insert(timestamped_key, value); Ok(is_exceeded) @@ -240,7 +240,7 @@ mod tests { ); fs.create_dir_all(&option.wal_dir_path()).await.unwrap(); - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let mem_table = Mutable::::new(&option, trigger, &fs, Arc::new(TestSchema {})) .await .unwrap(); @@ -293,7 +293,7 @@ mod tests { ); fs.create_dir_all(&option.wal_dir_path()).await.unwrap(); - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let mutable = Mutable::::new(&option, trigger, &fs, Arc::new(StringSchema)) .await @@ -389,7 +389,7 @@ mod tests { let fs = Arc::new(TokioFs) as Arc; fs.create_dir_all(&option.wal_dir_path()).await.unwrap(); - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let schema = Arc::new(schema); diff --git a/src/lib.rs b/src/lib.rs index a89799a..c34e4a8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -472,7 +472,7 @@ where pub immutables: Vec<(Option, Immutable<::Columns>)>, compaction_tx: Sender, recover_wal_ids: Option>, - trigger: Arc + Send + Sync>>, + trigger: Arc>, record_schema: Arc, } @@ -487,7 +487,7 @@ where record_schema: Arc, manager: &StoreManager, ) -> Result> { - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let mut schema = DbStorage { mutable: Mutable::new( &option, @@ -1169,7 +1169,7 @@ pub(crate) mod tests { option: Arc, fs: &Arc, ) -> Result<(crate::DbStorage, Receiver), fusio::Error> { - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let mutable = Mutable::new(&option, trigger.clone(), fs, Arc::new(TestSchema {})).await?; @@ -1211,7 +1211,7 @@ pub(crate) mod tests { .unwrap(); let immutables = { - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let mutable: Mutable = Mutable::new(&option, trigger.clone(), fs, Arc::new(TestSchema)).await?; @@ -1648,7 +1648,7 @@ pub(crate) mod tests { let (task_tx, _task_rx) = bounded(1); - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let schema: crate::DbStorage = crate::DbStorage { mutable: Mutable::new(&option, trigger.clone(), &fs, Arc::new(TestSchema)) .await @@ -1718,7 +1718,7 @@ pub(crate) mod tests { let (task_tx, _task_rx) = bounded(1); - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let schema: crate::DbStorage = crate::DbStorage { mutable: Mutable::new( &option, diff --git a/src/stream/mem_projection.rs b/src/stream/mem_projection.rs index c22de63..fda5e60 100644 --- a/src/stream/mem_projection.rs +++ b/src/stream/mem_projection.rs @@ -83,7 +83,7 @@ mod tests { fs.create_dir_all(&option.wal_dir_path()).await.unwrap(); - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let mutable = Mutable::::new(&option, trigger, &fs, Arc::new(TestSchema {})) .await diff --git a/src/stream/merge.rs b/src/stream/merge.rs index d4b3f5a..722d06e 100644 --- a/src/stream/merge.rs +++ b/src/stream/merge.rs @@ -178,7 +178,7 @@ mod tests { fs.create_dir_all(&option.wal_dir_path()).await.unwrap(); - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let m1 = Mutable::::new(&option, trigger, &fs, Arc::new(StringSchema)) .await @@ -194,7 +194,7 @@ mod tests { .await .unwrap(); - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let m2 = Mutable::::new(&option, trigger, &fs, Arc::new(StringSchema)) .await @@ -209,7 +209,7 @@ mod tests { .await .unwrap(); - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let m3 = Mutable::::new(&option, trigger, &fs, Arc::new(StringSchema)) .await @@ -281,7 +281,7 @@ mod tests { fs.create_dir_all(&option.wal_dir_path()).await.unwrap(); - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let m1 = Mutable::::new(&option, trigger, &fs, Arc::new(StringSchema)) .await @@ -372,7 +372,7 @@ mod tests { fs.create_dir_all(&option.wal_dir_path()).await.unwrap(); - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let m1 = Mutable::::new(&option, trigger, &fs, Arc::new(StringSchema)) .await diff --git a/src/stream/package.rs b/src/stream/package.rs index 39ba16f..31a16bc 100644 --- a/src/stream/package.rs +++ b/src/stream/package.rs @@ -115,7 +115,7 @@ mod tests { fs.create_dir_all(&option.wal_dir_path()).await.unwrap(); - let trigger = Arc::new(TriggerFactory::create(option.trigger_type)); + let trigger = TriggerFactory::create(option.trigger_type); let m1 = Mutable::::new(&option, trigger, &fs, Arc::new(TestSchema {})) .await diff --git a/src/trigger.rs b/src/trigger.rs index 1fbf616..9af96be 100644 --- a/src/trigger.rs +++ b/src/trigger.rs @@ -1,17 +1,19 @@ use std::{ - fmt, - fmt::Debug, marker::PhantomData, - sync::atomic::{AtomicUsize, Ordering}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, }; use crate::record::Record; -pub trait Trigger: fmt::Debug { - fn item(&self, item: &Option) -> bool; +pub trait Trigger: Send + Sync { + fn check_if_exceed(&self, item: &Option) -> bool; fn reset(&self); } + #[derive(Debug)] pub struct SizeOfMemTrigger { threshold: usize, @@ -30,7 +32,7 @@ impl SizeOfMemTrigger { } impl Trigger for SizeOfMemTrigger { - fn item(&self, item: &Option) -> bool { + fn check_if_exceed(&self, item: &Option) -> bool { let size = item.as_ref().map_or(0, R::size); self.current_size.fetch_add(size, Ordering::SeqCst) + size >= self.threshold } @@ -58,7 +60,7 @@ impl LengthTrigger { } impl Trigger for LengthTrigger { - fn item(&self, _: &Option) -> bool { + fn check_if_exceed(&self, _: &Option) -> bool { self.count.fetch_add(1, Ordering::SeqCst) + 1 >= self.threshold } @@ -78,10 +80,10 @@ pub(crate) struct TriggerFactory { } impl TriggerFactory { - pub fn create(trigger_type: TriggerType) -> Box + Send + Sync> { + pub fn create(trigger_type: TriggerType) -> Arc> { match trigger_type { - TriggerType::SizeOfMem(threshold) => Box::new(SizeOfMemTrigger::new(threshold)), - TriggerType::Length(threshold) => Box::new(LengthTrigger::new(threshold)), + TriggerType::SizeOfMem(threshold) => Arc::new(SizeOfMemTrigger::new(threshold)), + TriggerType::Length(threshold) => Arc::new(LengthTrigger::new(threshold)), } } } @@ -106,19 +108,19 @@ mod tests { assert_eq!(record_size, 8); assert!( - !trigger.item(&record), + !trigger.check_if_exceed(&record), "Trigger should not be exceeded after 1 record" ); - trigger.item(&record); + trigger.check_if_exceed(&record); assert!( - trigger.item(&record), + trigger.check_if_exceed(&record), "Trigger should be exceeded after 2 records" ); trigger.reset(); assert!( - !trigger.item(&record), + !trigger.check_if_exceed(&record), "Trigger should not be exceeded after reset" ); } @@ -135,19 +137,19 @@ mod tests { }); assert!( - !trigger.item(&record), + !trigger.check_if_exceed(&record), "Trigger should not be exceeded after 1 record" ); - trigger.item(&record); + trigger.check_if_exceed(&record); assert!( - trigger.item(&record), + trigger.check_if_exceed(&record), "Trigger should be exceeded after 2 records" ); trigger.reset(); assert!( - !trigger.item(&record), + !trigger.check_if_exceed(&record), "Trigger should not be exceeded after reset" ); } @@ -156,23 +158,23 @@ mod tests { let size_of_mem_trigger = TriggerFactory::::create(TriggerType::SizeOfMem(9)); let length_trigger = TriggerFactory::::create(TriggerType::Length(2)); - assert!(!size_of_mem_trigger.item(&Some(Test { + assert!(!size_of_mem_trigger.check_if_exceed(&Some(Test { vstring: "test".to_string(), vu32: 0, vbool: None }))); - assert!(size_of_mem_trigger.item(&Some(Test { + assert!(size_of_mem_trigger.check_if_exceed(&Some(Test { vstring: "test".to_string(), vu32: 0, vbool: None }))); - assert!(!length_trigger.item(&Some(Test { + assert!(!length_trigger.check_if_exceed(&Some(Test { vstring: "test".to_string(), vu32: 1, vbool: Some(true) }))); - assert!(length_trigger.item(&Some(Test { + assert!(length_trigger.check_if_exceed(&Some(Test { vstring: "test".to_string(), vu32: 1, vbool: Some(true)