diff --git a/Cargo.toml b/Cargo.toml index bff7e01..4e94ef1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,4 +6,5 @@ members = [ "wavesexchange_topic", "wavesexchange_loaders", "wavesexchange_apis", + "wavesexchange_utils", ] diff --git a/wavesexchange_utils/Cargo.toml b/wavesexchange_utils/Cargo.toml new file mode 100644 index 0000000..aa7e5df --- /dev/null +++ b/wavesexchange_utils/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "wavesexchange_utils" +version = "0.1.0" +edition = "2021" +authors = ["Artem Sidorenko "] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +thiserror = "1.0.38" +wavesexchange_log = { git = "https://github.com/waves-exchange/wavesexchange-rs", tag = "wavesexchange_log/0.5.1" } + +[dev-dependencies] +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } diff --git a/wavesexchange_utils/src/circuit_breaker/error.rs b/wavesexchange_utils/src/circuit_breaker/error.rs new file mode 100644 index 0000000..ea0cbff --- /dev/null +++ b/wavesexchange_utils/src/circuit_breaker/error.rs @@ -0,0 +1,7 @@ +use std::time::Duration; + +#[derive(Debug)] +pub enum CBError { + CircuitBroke { err_count: u16, elapsed: Duration }, + Inner(E), +} diff --git a/wavesexchange_utils/src/circuit_breaker/mod.rs b/wavesexchange_utils/src/circuit_breaker/mod.rs new file mode 100644 index 0000000..a92be25 --- /dev/null +++ b/wavesexchange_utils/src/circuit_breaker/mod.rs @@ -0,0 +1,191 @@ +mod error; + +pub use error::CBError; +use wavesexchange_log::debug; + +use std::{ + future::Future, + sync::{Arc, Mutex}, + time::{Duration, Instant}, +}; + +/// Count erroneous attempts while quering some data source. +/// +/// Example: +/// ```rust +/// use wavesexchange_utils::circuit_breaker::CircuitBreaker; +/// use std::time::Duration; +/// +/// #[tokio::main] +/// async fn main() { +/// struct Repo; +/// +/// #[derive(Debug)] +/// struct RepoError; +/// +/// let cb = CircuitBreaker::new( +/// Duration::from_secs(1), +/// 5, +/// Repo +/// ); +/// +/// cb.access(|src| async move { Err::<(), _>(RepoError) }).await.unwrap_err(); +/// cb.access(|src| async move { Ok::<_, ()>(()) }).await.unwrap() +/// +/// // see CB test below for more verbose example +/// } +/// ``` +pub struct CircuitBreaker { + /// Timespan that errors will be counted in. + /// After it elapsed, error counter will be resetted. + max_timespan: Duration, + + /// Maximum error count per timespan. Example: 3 errors per 1 sec (max_timespan) + max_err_count_per_timespan: u16, + + data_source: Arc, + + /// Current state of CB + state: Mutex, +} + +impl CircuitBreaker { + pub fn new(max_timespan: Duration, max_err_count_per_timespan: u16, data_source: S) -> Self { + Self { + max_timespan, + max_err_count_per_timespan, + data_source: Arc::new(data_source), + state: Mutex::new(CBState::default()), + } + } +} + +#[derive(Default)] +struct CBState { + err_count: u16, + first_err_ts: Option, +} + +impl CBState { + fn inc(&mut self) { + self.err_count += 1; + } + + fn reset(&mut self) { + self.err_count = 0; + self.first_err_ts = None; + } +} + +impl CircuitBreaker { + /// Access the data source. If succeeded, CB resets internal error counter. + /// If error returned, counter is increased. + /// If (N > max_err_count_per_timespan) errors appeared, CB breaks a circuit, + /// otherwise error counter will be reset. + pub async fn access(&self, query_fn: F) -> Result> + where + F: FnOnce(Arc) -> Fut, + Fut: Future>, + { + let result = query_fn(self.data_source.clone()).await; + self.handle_result(result) + } + + /// Sync version of `access` method. + pub fn access_blocking(&self, query_fn: F) -> Result> + where + F: FnOnce(Arc) -> Result, + { + let result = query_fn(self.data_source.clone()); + self.handle_result(result) + } + + fn handle_result(&self, result: Result) -> Result> { + let mut state = self.state.lock().unwrap(); + + if let Err(_) = &result { + state.inc(); + + debug!("CircuitBreaker: err count: {}", state.err_count); + + match state.first_err_ts { + Some(ts) => { + let elapsed = ts.elapsed(); + + if state.err_count <= self.max_err_count_per_timespan { + if elapsed > self.max_timespan { + state.reset(); + } + } else { + return Err(CBError::CircuitBroke { + err_count: state.err_count, + elapsed, + }); + } + } + None => state.first_err_ts = Some(Instant::now()), + } + } else { + if state.err_count > 0 { + state.reset(); + } + } + result.map_err(CBError::Inner) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct WildErrorGenerator; + + impl WildErrorGenerator { + fn err(&self) -> Result<(), WildError> { + Err(WildError) + } + } + + const EMPTY_OK: Result<(), ()> = Ok(()); + + #[derive(Debug)] + struct WildError; + + #[tokio::test] + async fn circuit_breaker() { + let cb = CircuitBreaker::new(Duration::from_secs(1), 2, WildErrorGenerator); + + // trigger 2 errors in cb + assert!(matches!( + cb.access(|weg| async move { weg.err() }).await.unwrap_err(), + CBError::Inner(WildError) + )); + + assert!(matches!( + cb.access(|weg| async move { weg.err() }).await.unwrap_err(), + CBError::Inner(WildError) + )); + + // reset cb state with successful query + assert_eq!(cb.access(|_weg| async move { EMPTY_OK }).await.unwrap(), ()); + + // trigger 3 errors in cb (max errors limit exceeded) + assert!(matches!( + cb.access(|weg| async move { weg.err() }).await.unwrap_err(), + CBError::Inner(WildError) + )); + + assert!(matches!( + cb.access(|weg| async move { weg.err() }).await.unwrap_err(), + CBError::Inner(WildError) + )); + + // break circuit + assert!(matches!( + cb.access(|weg| async move { weg.err() }).await.unwrap_err(), + CBError::CircuitBroke { .. } + )); + + assert_eq!(cb.access(|_weg| async move { EMPTY_OK }).await.unwrap(), ()); + } +} diff --git a/wavesexchange_utils/src/lib.rs b/wavesexchange_utils/src/lib.rs new file mode 100644 index 0000000..f120e4c --- /dev/null +++ b/wavesexchange_utils/src/lib.rs @@ -0,0 +1 @@ +pub mod circuit_breaker;