Skip to content

Commit

Permalink
refactor: update cost model to use history
Browse files Browse the repository at this point in the history
Signed-off-by: Gustavo Inacio <[email protected]>
  • Loading branch information
gusinacio committed Oct 10, 2024
1 parent 0ae3eb8 commit edd67f8
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 48 deletions.
27 changes: 13 additions & 14 deletions common/src/tap/checks/value_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use std::{
time::Duration,
};
use thegraph_core::DeploymentId;
use tokio::task::JoinHandle;
use tracing::error;
use ttl_cache::TtlCache;

Expand All @@ -24,7 +23,15 @@ use tap_core::receipt::{

/// Receipt check that validates a query's value against the indexer's
/// Agora cost models (see `get_expected_value`), kept fresh via a Postgres
/// LISTEN/NOTIFY watcher on the 'cost_models_update_notify' channel.
pub struct MinimumValue {
// Per-deployment cache of cost models, shared (Arc<Mutex<..>>) with the
// background `cost_models_watcher` task that updates it on DB notifications.
cost_model_cache: Arc<Mutex<HashMap<DeploymentId, CostModelCache>>>,
// Cancellation handle for the spawned cost-models watcher task;
// cancelled in Drop so the task shuts down with this check.
watcher_cancel_token: tokio_util::sync::CancellationToken,
}

impl Drop for MinimumValue {
fn drop(&mut self) {
// Clean shutdown for the cost-models watcher task spawned in `new`
// (NOTE(review): the previous comment said "sender_denylist_watcher",
// a copy-paste error — this token belongs to `cost_models_watcher`).
// Since it's not a critical task, we don't wait for it to finish (join).
self.watcher_cancel_token.cancel();
}
}

impl MinimumValue {
Expand All @@ -37,19 +44,17 @@ impl MinimumValue {
'cost_models_update_notify'",
);

// TODO start watcher
let cancel_token = tokio_util::sync::CancellationToken::new();

let model_handle = tokio::spawn(Self::cost_models_watcher(
let watcher_cancel_token = tokio_util::sync::CancellationToken::new();
tokio::spawn(Self::cost_models_watcher(
pgpool.clone(),
pglistener,
cost_model_cache.clone(),
cancel_token.clone(),
watcher_cancel_token.clone(),
));

Self {
cost_model_cache,
model_handle,
watcher_cancel_token,
}
}

Expand Down Expand Up @@ -128,12 +133,6 @@ impl MinimumValue {
}
}

impl Drop for MinimumValue {
fn drop(&mut self) {
self.model_handle.abort();
}
}

impl MinimumValue {
fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result<u128> {
// get agora model for the allocation_id
Expand Down
8 changes: 7 additions & 1 deletion migrations/20230901142040_cost_models.down.sql
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
-- Add down migration script here
DROP TABLE "CostModels";
DROP TRIGGER IF EXISTS cost_models_update ON "CostModelsHistory" CASCADE;

DROP FUNCTION IF EXISTS cost_models_update_notify() CASCADE;

DROP VIEW "CostModels";

DROP TABLE "CostModelsHistory";
43 changes: 40 additions & 3 deletions migrations/20230901142040_cost_models.up.sql
Original file line number Diff line number Diff line change
@@ -1,8 +1,45 @@
CREATE TABLE IF NOT EXISTS "CostModels"
CREATE TABLE IF NOT EXISTS "CostModelsHistory"
(
id INT,
id SERIAL PRIMARY KEY,
deployment VARCHAR NOT NULL,
model TEXT,
variables JSONB,
PRIMARY KEY( deployment )
"createdAt" TIMESTAMP WITH TIME ZONE,
"updatedAt" TIMESTAMP WITH TIME ZONE
);

-- "CostModels" exposes only the most recent history row per deployment
-- (highest id wins), preserving the old table's column shape so existing
-- readers keep working against the new "CostModelsHistory" storage.
CREATE VIEW "CostModels" AS
SELECT id,
       deployment,
       model,
       variables,
       "createdAt",
       "updatedAt"
FROM "CostModelsHistory" t1
JOIN (
    -- Latest history id per deployment. Name the aggregate explicitly:
    -- the original joined on `t2.MAX`, relying on Postgres folding the
    -- unnamed MAX(id) column to "max" — fragile and easy to misread.
    SELECT MAX(id) AS max_id
    FROM "CostModelsHistory"
    GROUP BY deployment
) t2 ON t1.id = t2.max_id;

-- Row-level trigger function: publishes a JSON payload on the
-- 'cost_models_update_notification' channel for every change to
-- "CostModelsHistory", so in-process caches (e.g. the Rust
-- cost_models_watcher) can refresh without polling.
CREATE FUNCTION cost_models_update_notify()
RETURNS trigger AS
$$
BEGIN
    IF TG_OP = 'DELETE' THEN
        PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "DELETE", "deployment": "%s"}', OLD.deployment));
        RETURN OLD;
    ELSIF TG_OP = 'INSERT' THEN
        PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "INSERT", "deployment": "%s", "model": "%s"}', NEW.deployment, NEW.model));
        RETURN NEW;
    ELSE
        -- BUG FIX: the format string has three %s placeholders but the
        -- original passed only two arguments (TG_OP was missing), so
        -- format() raised "too few arguments" on every UPDATE.
        PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "%s", "deployment": "%s", "model": "%s"}', TG_OP, NEW.deployment, NEW.model));
        RETURN NEW;
    END IF;
END;
$$ LANGUAGE 'plpgsql';

-- Fire the notify function after every row-level change to the history
-- table, covering all three operations the function's branches handle.
CREATE TRIGGER cost_models_update AFTER INSERT OR UPDATE OR DELETE
ON "CostModelsHistory"
FOR EACH ROW EXECUTE PROCEDURE cost_models_update_notify();
44 changes: 14 additions & 30 deletions service/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub async fn connect(url: &str) -> PgPool {
/// These can have "global" as the deployment ID.
#[derive(Debug, Clone)]
struct DbCostModel {
pub deployment: String,
pub deployment: Option<String>,
pub model: Option<String>,
pub variables: Option<Value>,
}
Expand All @@ -46,7 +46,12 @@ impl TryFrom<DbCostModel> for CostModel {

fn try_from(db_model: DbCostModel) -> Result<Self, Self::Error> {
Ok(Self {
deployment: DeploymentId::from_str(&db_model.deployment)?,
deployment: DeploymentId::from_str(&db_model.deployment.ok_or(
ParseDeploymentIdError::InvalidIpfsHashLength {
value: String::new(),
length: 0,
},
)?)?,
model: db_model.model,
variables: db_model.variables,
})
Expand All @@ -57,7 +62,7 @@ impl From<CostModel> for DbCostModel {
fn from(model: CostModel) -> Self {
let deployment = model.deployment;
DbCostModel {
deployment: format!("{deployment:#x}"),
deployment: Some(format!("{deployment:#x}")),
model: model.model,
variables: model.variables,
}
Expand Down Expand Up @@ -210,28 +215,11 @@ mod test {

use super::*;

async fn setup_cost_models_table(pool: &PgPool) {
sqlx::query!(
r#"
CREATE TABLE "CostModels"(
id INT,
deployment VARCHAR NOT NULL,
model TEXT,
variables JSONB,
PRIMARY KEY( deployment )
);
"#,
)
.execute(pool)
.await
.expect("Create test instance in db");
}

async fn add_cost_models(pool: &PgPool, models: Vec<DbCostModel>) {
for model in models {
sqlx::query!(
r#"
INSERT INTO "CostModels" (deployment, model)
INSERT INTO "CostModelsHistory" (deployment, model)
VALUES ($1, $2);
"#,
model.deployment,
Expand All @@ -249,7 +237,7 @@ mod test {

fn global_cost_model() -> DbCostModel {
DbCostModel {
deployment: "global".to_string(),
deployment: Some("global".to_string()),
model: Some("default => 0.00001;".to_string()),
variables: None,
}
Expand Down Expand Up @@ -281,15 +269,14 @@ mod test {
]
}

#[sqlx::test]
#[sqlx::test(migrations = "../migrations")]
async fn success_cost_models(pool: PgPool) {
let test_models = test_data();
let test_deployments = test_models
.iter()
.map(|model| model.deployment)
.collect::<HashSet<_>>();

setup_cost_models_table(&pool).await;
add_cost_models(&pool, to_db_models(test_models.clone())).await;

// First test: query without deployment filter
Expand Down Expand Up @@ -344,7 +331,7 @@ mod test {
}
}

#[sqlx::test]
#[sqlx::test(migrations = "../migrations")]
async fn global_fallback_cost_models(pool: PgPool) {
let test_models = test_data();
let test_deployments = test_models
Expand All @@ -353,7 +340,6 @@ mod test {
.collect::<HashSet<_>>();
let global_model = global_cost_model();

setup_cost_models_table(&pool).await;
add_cost_models(&pool, to_db_models(test_models.clone())).await;
add_cost_models(&pool, vec![global_model.clone()]).await;

Expand Down Expand Up @@ -436,9 +422,8 @@ mod test {
assert_eq!(missing_model.model, global_model.model);
}

#[sqlx::test]
#[sqlx::test(migrations = "../migrations")]
async fn success_cost_model(pool: PgPool) {
setup_cost_models_table(&pool).await;
add_cost_models(&pool, to_db_models(test_data())).await;

let deployment_id_from_bytes = DeploymentId::from_str(
Expand All @@ -459,12 +444,11 @@ mod test {
assert_eq!(model.model, Some("default => 0.00025;".to_string()));
}

#[sqlx::test]
#[sqlx::test(migrations = "../migrations")]
async fn global_fallback_cost_model(pool: PgPool) {
let test_models = test_data();
let global_model = global_cost_model();

setup_cost_models_table(&pool).await;
add_cost_models(&pool, to_db_models(test_models.clone())).await;
add_cost_models(&pool, vec![global_model.clone()]).await;

Expand Down

0 comments on commit edd67f8

Please sign in to comment.