From d2b5b15a7428a25e2fd365e599548af8224295a3 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Tue, 8 Oct 2024 15:29:16 +0200 Subject: [PATCH] refactor: add value reload Signed-off-by: Gustavo Inacio --- common/src/tap/checks/value_check.rs | 100 +++++++++++++++++++++------ 1 file changed, 77 insertions(+), 23 deletions(-) diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs index 396e35c0..d28c5c31 100644 --- a/common/src/tap/checks/value_check.rs +++ b/common/src/tap/checks/value_check.rs @@ -8,7 +8,8 @@ use sqlx::{postgres::PgListener, PgPool}; use std::{ cmp::min, collections::HashMap, - sync::{Arc, Mutex}, + str::FromStr, + sync::{Arc, Mutex, RwLock}, time::Duration, }; use thegraph_core::DeploymentId; @@ -22,7 +23,7 @@ use tap_core::receipt::{ }; pub struct MinimumValue { - cost_model_cache: Arc>>, + cost_model_cache: Arc>>>, watcher_cancel_token: tokio_util::sync::CancellationToken, } @@ -36,7 +37,9 @@ impl Drop for MinimumValue { impl MinimumValue { pub async fn new(pgpool: PgPool) -> Self { - let cost_model_cache = Arc::new(Mutex::new(HashMap::::new())); + let cost_model_cache = Arc::new(RwLock::new( + HashMap::>::new(), + )); let mut pglistener = PgListener::connect_with(&pgpool.clone()).await.unwrap(); pglistener.listen("cost_models_update_notify").await.expect( @@ -58,10 +61,23 @@ impl MinimumValue { } } + fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result { + // get agora model for the allocation_id + let cache = self.cost_model_cache.read().unwrap(); + // on average, we'll have zero or one model + let models = cache.get(&agora_query.deployment_id); + + let expected_value = models + .map(|cache| cache.lock().unwrap().cost(agora_query)) + .unwrap_or_default(); + + Ok(expected_value) + } + async fn cost_models_watcher( - _pgpool: PgPool, + pgpool: PgPool, mut pglistener: PgListener, - cost_model_cache: Arc>>, + cost_model_cache: Arc>>>, cancel_token: tokio_util::sync::CancellationToken, ) { loop { @@ -88,12 +104,12 @@ impl MinimumValue { "INSERT" => { let cost_model_source: CostModelSource = cost_model_notification.into(); let mut cost_model_cache = cost_model_cache - .lock() + .write() .unwrap(); match cost_model_cache.get_mut(&deployment_id) { Some(cache) => { - let _ = cache.insert_model(cost_model_source); + let _ = cache.lock().unwrap().insert_model(cost_model_source); }, None => { if let Ok(cache) = CostModelCache::new(cost_model_source).inspect_err(|err| { @@ -102,14 +118,14 @@ impl MinimumValue { deployment_id, err ) }) { - cost_model_cache.insert(deployment_id, cache); + cost_model_cache.insert(deployment_id, Mutex::new(cache)); } }, } } "DELETE" => { cost_model_cache - .lock() + .write() .unwrap() .remove(&cost_model_notification.deployment); } @@ -122,29 +138,47 @@ impl MinimumValue { cost_model_notification.tg_op ); - // Self::sender_denylist_reload(pgpool.clone(), denylist.clone()) - // .await - // .expect("should be able to reload cost models") + Self::value_check_reload(&pgpool, cost_model_cache.clone()) + .await + .expect("should be able to reload cost models") } } } } } } -} -impl MinimumValue { - fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result { - // get agora model for the allocation_id - let mut cache = self.cost_model_cache.lock().unwrap(); - // on average, we'll have zero or one model - let models = cache.get_mut(&agora_query.deployment_id); + async fn value_check_reload( + pgpool: &PgPool, + cost_model_cache: Arc>>>, + ) -> anyhow::Result<()> { + let models = sqlx::query!( + r#" + SELECT deployment, model, variables + FROM "CostModels" + WHERE deployment != 'global' + ORDER BY deployment ASC + "# + ) + .fetch_all(pgpool) + .await?; + let models = models + .into_iter() + .map(|record| { + let deployment_id = DeploymentId::from_str(&record.deployment.unwrap())?; + let model = CostModelCache::new(CostModelSource { + deployment_id, + model: record.model.unwrap(), + variables: record.variables.unwrap().to_string(), + })?; + + Ok::<_, anyhow::Error>((deployment_id, Mutex::new(model))) + }) + .collect::, _>>()?; - let expected_value = models - .map(|cache| cache.cost(agora_query)) - .unwrap_or_default(); + *(cost_model_cache.write().unwrap()) = models; - Ok(expected_value) + Ok(()) } } @@ -279,3 +313,23 @@ impl CostModelCache { .unwrap_or_default() } } + +#[cfg(test)] +mod tests { + use sqlx::PgPool; + + #[sqlx::test(migrations = "../migrations")] + async fn initialize_check(pg_pool: PgPool) {} + + #[sqlx::test(migrations = "../migrations")] + async fn should_initialize_check_with_caches(pg_pool: PgPool) {} + + #[sqlx::test(migrations = "../migrations")] + async fn should_add_model_to_cache_on_insert(pg_pool: PgPool) {} + + #[sqlx::test(migrations = "../migrations")] + async fn should_expire_old_model(pg_pool: PgPool) {} + + #[sqlx::test(migrations = "../migrations")] + async fn should_verify_global_model(pg_pool: PgPool) {} +}