diff --git a/src/db.rs b/src/db.rs index e57fd65..5428b65 100644 --- a/src/db.rs +++ b/src/db.rs @@ -101,7 +101,7 @@ impl KbsDb { } fn replace_binds(&self, sql: &str) -> String { - if self.r#type != DbType::PostgreSQL { + if self.r#type != DbType::Postgres { return sql.to_string(); } @@ -185,41 +185,70 @@ impl KbsDb { let allowed_policies_json = serde_json::to_string(&policy.allowed_policies)?; let allowed_build_ids_json = serde_json::to_string(&policy.allowed_build_ids)?; - let mut query_str = if self.r#type == DbType::Sqlite { - String::from("INSERT INTO policy (allowed_digests, allowed_policies, min_fw_api_major, min_fw_api_minor, allowed_build_ids, create_date, valid) VALUES(?, ?, ?, ?, ?, DATE('now'), 1)") - } else { - String::from("INSERT INTO policy (allowed_digests, allowed_policies, min_fw_api_major, min_fw_api_minor, allowed_build_ids, create_date, valid) VALUES(?, ?, ?, ?, ?, NOW(), 1)") - }; + let query = String::from( + "INSERT INTO policy + (allowed_digests, allowed_policies, min_fw_api_major, min_fw_api_minor, allowed_build_ids, create_date, valid) + VALUES(?, ?, ?, ?, ?, DATE('now'), 1)" + ); + + let policy_id = match self.r#type { + DbType::MySQL => { + let id = sqlx::query(&query) + .bind(allowed_digests_json) + .bind(allowed_policies_json) + .bind(policy.min_fw_api_major as i64) + .bind(policy.min_fw_api_minor as i64) + .bind(allowed_build_ids_json) + .execute(&self.dbpool) + .await? + .last_insert_id() + .ok_or(anyhow!("Failed to insert policy"))?; - if self.r#type == DbType::MySQL || self.r#type == DbType::Sqlite { - let last_insert_row = sqlx::query(&query_str) - .bind(allowed_digests_json) - .bind(allowed_policies_json) - .bind(policy.min_fw_api_major as i64) - .bind(policy.min_fw_api_minor as i64) - .bind(allowed_build_ids_json) - .execute(&self.dbpool) - .await? - .last_insert_id(); - match last_insert_row { - Some(p) => Ok(p as u64), - None => Err(anyhow!( - "db::insert_policy- error, last_insert_id() returned None" - )), + id as u64 } - } else { - query_str.push_str("RETURNING id"); - let new_query_str = self.replace_binds(&query_str); - let last_insert_row = sqlx::query(&new_query_str) - .bind(allowed_digests_json) - .bind(allowed_policies_json) - .bind(policy.min_fw_api_major as i64) - .bind(policy.min_fw_api_minor as i64) - .bind(allowed_build_ids_json) - .fetch_one(&self.dbpool) - .await?; - Ok(last_insert_row.try_get::(0)? as u64) - } + DbType::Sqlite => { + let query = String::from( + "INSERT INTO policy + (allowed_digests, allowed_policies, min_fw_api_major, min_fw_api_minor, allowed_build_ids, create_date, valid) + VALUES(?, ?, ?, ?, ?, DATE('now'), 1)" + ); + + sqlx::query(&query) + .bind(allowed_digests_json) + .bind(allowed_policies_json) + .bind(policy.min_fw_api_major as i64) + .bind(policy.min_fw_api_minor as i64) + .bind(allowed_build_ids_json) + .execute(&self.dbpool) + .await + .map_err(|e| anyhow!("Failed to insert policy: {e}"))?; + + // suboptimal workaround to avoid issues with last_insert_rowid() + let query = String::from("SELECT MAX(id) from policy"); + let id = sqlx::query(&query) + .fetch_one(&self.dbpool) + .await? + .try_get::(0)?; + + id as u64 + } + DbType::Postgres => { + let query = self.replace_binds(&query); + + let last_insert_row = sqlx::query(&query) + .bind(allowed_digests_json) + .bind(allowed_policies_json) + .bind(policy.min_fw_api_major as i64) + .bind(policy.min_fw_api_minor as i64) + .bind(allowed_build_ids_json) + .fetch_one(&self.dbpool) + .await?; + + last_insert_row.try_get::(0)? as u64 + } + }; + + Ok(policy_id) } pub async fn get_policy(&self, pid: u64) -> Result {