refactor: add value reload
Signed-off-by: Gustavo Inacio <[email protected]>
gusinacio committed Oct 10, 2024
1 parent edd67f8 commit d2b5b15
Showing 1 changed file with 77 additions and 23 deletions.
common/src/tap/checks/value_check.rs: 77 additions & 23 deletions
@@ -8,7 +8,8 @@ use sqlx::{postgres::PgListener, PgPool};
use std::{
cmp::min,
collections::HashMap,
sync::{Arc, Mutex},
str::FromStr,
sync::{Arc, Mutex, RwLock},
time::Duration,
};
use thegraph_core::DeploymentId;
@@ -22,7 +23,7 @@ use tap_core::receipt::{
};

pub struct MinimumValue {
cost_model_cache: Arc<Mutex<HashMap<DeploymentId, CostModelCache>>>,
cost_model_cache: Arc<RwLock<HashMap<DeploymentId, Mutex<CostModelCache>>>>,
watcher_cancel_token: tokio_util::sync::CancellationToken,
}

@@ -36,7 +37,9 @@ impl Drop for MinimumValue {

impl MinimumValue {
pub async fn new(pgpool: PgPool) -> Self {
let cost_model_cache = Arc::new(Mutex::new(HashMap::<DeploymentId, CostModelCache>::new()));
let cost_model_cache = Arc::new(RwLock::new(
HashMap::<DeploymentId, Mutex<CostModelCache>>::new(),
));

let mut pglistener = PgListener::connect_with(&pgpool.clone()).await.unwrap();
pglistener.listen("cost_models_update_notify").await.expect(
@@ -58,10 +61,23 @@ impl MinimumValue {
}
}

fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result<u128> {
// get agora model for the allocation_id
let cache = self.cost_model_cache.read().unwrap();
// on average, we'll have zero or one model
let models = cache.get(&agora_query.deployment_id);

let expected_value = models
.map(|cache| cache.lock().unwrap().cost(agora_query))
.unwrap_or_default();

Ok(expected_value)
}

async fn cost_models_watcher(
_pgpool: PgPool,
pgpool: PgPool,
mut pglistener: PgListener,
cost_model_cache: Arc<Mutex<HashMap<DeploymentId, CostModelCache>>>,
cost_model_cache: Arc<RwLock<HashMap<DeploymentId, Mutex<CostModelCache>>>>,
cancel_token: tokio_util::sync::CancellationToken,
) {
loop {
@@ -88,12 +104,12 @@ impl MinimumValue {
"INSERT" => {
let cost_model_source: CostModelSource = cost_model_notification.into();
let mut cost_model_cache = cost_model_cache
.lock()
.write()
.unwrap();

match cost_model_cache.get_mut(&deployment_id) {
Some(cache) => {
let _ = cache.insert_model(cost_model_source);
let _ = cache.lock().unwrap().insert_model(cost_model_source);
},
None => {
if let Ok(cache) = CostModelCache::new(cost_model_source).inspect_err(|err| {
@@ -102,14 +118,14 @@ impl MinimumValue {
deployment_id, err
)
}) {
cost_model_cache.insert(deployment_id, cache);
cost_model_cache.insert(deployment_id, Mutex::new(cache));
}
},
}
}
"DELETE" => {
cost_model_cache
.lock()
.write()
.unwrap()
.remove(&cost_model_notification.deployment);
}
@@ -122,29 +138,47 @@ impl MinimumValue {
cost_model_notification.tg_op
);

// Self::sender_denylist_reload(pgpool.clone(), denylist.clone())
// .await
// .expect("should be able to reload cost models")
Self::value_check_reload(&pgpool, cost_model_cache.clone())
.await
.expect("should be able to reload cost models")
}
}
}
}
}
}
}

impl MinimumValue {
fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result<u128> {
// get agora model for the allocation_id
let mut cache = self.cost_model_cache.lock().unwrap();
// on average, we'll have zero or one model
let models = cache.get_mut(&agora_query.deployment_id);
async fn value_check_reload(
pgpool: &PgPool,
cost_model_cache: Arc<RwLock<HashMap<DeploymentId, Mutex<CostModelCache>>>>,
) -> anyhow::Result<()> {
let models = sqlx::query!(
r#"
SELECT deployment, model, variables
FROM "CostModels"
WHERE deployment != 'global'
ORDER BY deployment ASC
"#
)
.fetch_all(pgpool)
.await?;
let models = models
.into_iter()
.map(|record| {
let deployment_id = DeploymentId::from_str(&record.deployment.unwrap())?;
let model = CostModelCache::new(CostModelSource {
deployment_id,
model: record.model.unwrap(),
variables: record.variables.unwrap().to_string(),
})?;

Ok::<_, anyhow::Error>((deployment_id, Mutex::new(model)))
})
.collect::<Result<HashMap<_, _>, _>>()?;

let expected_value = models
.map(|cache| cache.cost(agora_query))
.unwrap_or_default();
*(cost_model_cache.write().unwrap()) = models;

Ok(expected_value)
Ok(())
}
}

@@ -279,3 +313,23 @@ impl CostModelCache {
.unwrap_or_default()
}
}

#[cfg(test)]
mod tests {
use sqlx::PgPool;

#[sqlx::test(migrations = "../migrations")]
async fn initialize_check(pg_pool: PgPool) {}

#[sqlx::test(migrations = "../migrations")]
async fn should_initialize_check_with_caches(pg_pool: PgPool) {}

#[sqlx::test(migrations = "../migrations")]
async fn should_add_model_to_cache_on_insert(pg_pool: PgPool) {}

#[sqlx::test(migrations = "../migrations")]
async fn should_expire_old_model(pg_pool: PgPool) {}

#[sqlx::test(migrations = "../migrations")]
async fn should_verify_global_model(pg_pool: PgPool) {}
}
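
The heart of this change is the cache layout: a single Arc<Mutex<HashMap<DeploymentId, CostModelCache>>> becomes Arc<RwLock<HashMap<DeploymentId, Mutex<CostModelCache>>>>, and value_check_reload rebuilds the whole map from the "CostModels" table before swapping it in under the write lock. The sketch below illustrates that locking pattern only; DeploymentId and CostModelCache are replaced with simplified stand-ins (a String key and a toy cache whose cost() just counts stored models), so it is an illustration of the pattern rather than the crate's actual types.

use std::{
    collections::HashMap,
    sync::{Arc, Mutex, RwLock},
};

// Stand-ins for the crate's DeploymentId and CostModelCache types, simplified
// so the sketch is self-contained.
type DeploymentId = String;

struct CostModelCache {
    models: Vec<String>,
}

impl CostModelCache {
    // Placeholder: the real cache evaluates an Agora cost model against the query.
    fn cost(&self) -> u128 {
        self.models.len() as u128
    }
}

type Cache = Arc<RwLock<HashMap<DeploymentId, Mutex<CostModelCache>>>>;

// Readers take the shared read() lock and then lock only the entry they need,
// so checks for different deployments do not contend with each other.
fn expected_value(cache: &Cache, deployment: &str) -> u128 {
    cache
        .read()
        .unwrap()
        .get(deployment)
        .map(|entry| entry.lock().unwrap().cost())
        .unwrap_or_default()
}

// A full reload builds a fresh map off to the side and swaps it in under the
// write() lock, so readers are only blocked for the swap itself.
fn reload(cache: &Cache, fresh: HashMap<DeploymentId, Mutex<CostModelCache>>) {
    *cache.write().unwrap() = fresh;
}

fn main() {
    let cache: Cache = Arc::new(RwLock::new(HashMap::new()));

    // Simulate value_check_reload: rebuild the whole cache, then swap it in.
    let mut fresh = HashMap::new();
    fresh.insert(
        "deployment-a".to_string(),
        Mutex::new(CostModelCache {
            models: vec!["default => 0.01;".to_string()],
        }),
    );
    reload(&cache, fresh);

    assert_eq!(expected_value(&cache, "deployment-a"), 1);
    assert_eq!(expected_value(&cache, "deployment-b"), 0);
}

With this split, concurrent value checks for different deployments only share the read lock, while inserts, deletes, and the full reload serialize on the write lock.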
