From fb2b20f24eb4da48f2c9e944860d5083a4aaaf9e Mon Sep 17 00:00:00 2001 From: Martin Raszyk Date: Mon, 13 Jan 2025 13:17:19 +0100 Subject: [PATCH] chore(upgrader): assume that target canister is always set in production --- core/upgrader/impl/src/lib.rs | 24 ++-- .../impl/src/services/disaster_recovery.rs | 91 +++------------- core/upgrader/impl/src/upgrade.rs | 103 +++++++----------- 3 files changed, 70 insertions(+), 148 deletions(-) diff --git a/core/upgrader/impl/src/lib.rs b/core/upgrader/impl/src/lib.rs index e4fd1aeb1..c31bd72db 100644 --- a/core/upgrader/impl/src/lib.rs +++ b/core/upgrader/impl/src/lib.rs @@ -10,7 +10,7 @@ use ic_stable_structures::{ }; use lazy_static::lazy_static; use orbit_essentials::storable; -use std::{cell::RefCell, sync::Arc, thread::LocalKey}; +use std::{cell::RefCell, sync::Arc}; use upgrade::{UpgradeError, UpgradeParams}; use upgrader_api::{InitArg, TriggerUpgradeError}; @@ -29,7 +29,6 @@ pub mod utils; type Memory = VirtualMemory; type StableMap = StableBTreeMap; type StableValue = StableMap<(), T>; -type LocalRef = &'static LocalKey>; const MEMORY_ID_TARGET_CANISTER_ID: u8 = 0; const MEMORY_ID_DISASTER_RECOVERY: u8 = 1; @@ -51,6 +50,15 @@ thread_local! { ); } +pub fn get_target_canister() -> Principal { + TARGET_CANISTER_ID.with(|id| { + id.borrow() + .get(&()) + .map(|id| id.0) + .unwrap_or(Principal::anonymous()) + }) +} + #[init] fn init_fn(InitArg { target_canister }: InitArg) { TARGET_CANISTER_ID.with(|id| { @@ -61,13 +69,13 @@ fn init_fn(InitArg { target_canister }: InitArg) { lazy_static! { static ref UPGRADER: Box = { - let u = Upgrader::new(&TARGET_CANISTER_ID); - let u = WithStop(u, &TARGET_CANISTER_ID); - let u = WithStart(u, &TARGET_CANISTER_ID); + let u = Upgrader {}; + let u = WithStop(u); + let u = WithStart(u); let u = WithLogs(u, "upgrade".to_string()); - let u = WithBackground(Arc::new(u), &TARGET_CANISTER_ID); - let u = CheckController(u, &TARGET_CANISTER_ID); - let u = WithAuthorization(u, &TARGET_CANISTER_ID); + let u = WithBackground(Arc::new(u)); + let u = CheckController(u); + let u = WithAuthorization(u); let u = WithLogs(u, "trigger_upgrade".to_string()); Box::new(u) }; diff --git a/core/upgrader/impl/src/services/disaster_recovery.rs b/core/upgrader/impl/src/services/disaster_recovery.rs index 3ee5cded1..dd3ee7073 100644 --- a/core/upgrader/impl/src/services/disaster_recovery.rs +++ b/core/upgrader/impl/src/services/disaster_recovery.rs @@ -1,35 +1,29 @@ -use std::{ - cell::RefCell, - collections::{HashMap, HashSet}, - sync::Arc, -}; - +use super::{InstallCanister, LoggerService, INSTALL_CANISTER}; use crate::{ errors::UpgraderApiError, + get_target_canister, model::{ - Asset, DisasterRecoveryInProgressLog, DisasterRecoveryResultLog, DisasterRecoveryStartLog, - LogEntryType, MultiAssetAccount, RequestDisasterRecoveryLog, SetAccountsAndAssetsLog, - SetAccountsLog, SetCommitteeLog, + Account, AdminUser, Asset, DisasterRecovery, DisasterRecoveryCommittee, + DisasterRecoveryInProgressLog, DisasterRecoveryResultLog, DisasterRecoveryStartLog, + InstallMode, LogEntryType, MultiAssetAccount, RecoveryEvaluationResult, RecoveryFailure, + RecoveryResult, RecoveryStatus, RequestDisasterRecoveryLog, SetAccountsAndAssetsLog, + SetAccountsLog, SetCommitteeLog, StationRecoveryRequest, }, services::LOGGER_SERVICE, upgrader_ic_cdk::{api::time, spawn}, + StableValue, MEMORY_ID_DISASTER_RECOVERY, MEMORY_MANAGER, }; + use candid::Principal; use ic_stable_structures::memory_manager::MemoryId; use lazy_static::lazy_static; use orbit_essentials::{api::ServiceResult, utils::sha256_hash}; - -use crate::{ - model::{ - Account, AdminUser, DisasterRecovery, DisasterRecoveryCommittee, InstallMode, - RecoveryEvaluationResult, RecoveryFailure, RecoveryResult, RecoveryStatus, - StationRecoveryRequest, - }, - StableValue, MEMORY_ID_DISASTER_RECOVERY, MEMORY_MANAGER, TARGET_CANISTER_ID, +use std::{ + cell::RefCell, + collections::{HashMap, HashSet}, + sync::Arc, }; -use super::{InstallCanister, LoggerService, INSTALL_CANISTER}; - pub const DISASTER_RECOVERY_REQUEST_EXPIRATION_NS: u64 = 60 * 60 * 24 * 7 * 1_000_000_000; // 1 week pub const DISASTER_RECOVERY_IN_PROGESS_EXPIRATION_NS: u64 = 60 * 60 * 1_000_000_000; // 1 hour @@ -299,15 +293,7 @@ impl DisasterRecoveryService { return; } - let Some(station_canister_id) = - TARGET_CANISTER_ID.with(|id| id.borrow().get(&()).map(|id| id.0)) - else { - value.last_recovery_result = Some(RecoveryResult::Failure(RecoveryFailure { - reason: "Station canister ID not set".to_string(), - })); - storage.set(value); - return; - }; + let station_canister_id = get_target_canister(); value.recovery_status = RecoveryStatus::InProgress { since: time() }; storage.set(value); @@ -432,7 +418,6 @@ mod tests { services::{ DisasterRecoveryService, DisasterRecoveryStorage, InstallCanister, LoggerService, }, - StorablePrincipal, TARGET_CANISTER_ID, }; #[derive(Default)] @@ -592,11 +577,6 @@ mod tests { #[tokio::test] async fn test_do_recovery() { - TARGET_CANISTER_ID.with(|id| { - id.borrow_mut() - .insert((), StorablePrincipal(Principal::anonymous())); - }); - let storage: DisasterRecoveryStorage = Default::default(); let logger = Arc::new(LoggerService::default()); let recovery_request = StationRecoveryRequest { @@ -672,51 +652,8 @@ mod tests { ); } - #[tokio::test] - async fn test_failing_do_recovery_with_no_target_canister_id() { - // setup: TARGET_CANISTER_ID is not set, so recovery should fail - - let storage: DisasterRecoveryStorage = Default::default(); - let logger = Arc::new(LoggerService::default()); - let recovery_request = StationRecoveryRequest { - user_id: [1; 16], - wasm_module: vec![1, 2, 3], - wasm_module_extra_chunks: None, - wasm_sha256: vec![4, 5, 6], - install_mode: InstallMode::Reinstall, - arg: vec![7, 8, 9], - arg_sha256: vec![10, 11, 12], - submitted_at: 0, - }; - - let installer = Arc::new(TestInstaller::default()); - - DisasterRecoveryService::do_recovery( - storage.clone(), - installer.clone(), - logger.clone(), - recovery_request.clone(), - ) - .await; - - assert!(matches!( - storage.get().last_recovery_result, - Some(RecoveryResult::Failure(_)) - )); - - assert!(matches!( - storage.get().recovery_status, - RecoveryStatus::Idle - )); - } - #[tokio::test] async fn test_failing_do_recovery_with_panicking_install() { - TARGET_CANISTER_ID.with(|id| { - id.borrow_mut() - .insert((), StorablePrincipal(Principal::anonymous())); - }); - let storage: DisasterRecoveryStorage = Default::default(); let logger = Arc::new(LoggerService::default()); let recovery_request = StationRecoveryRequest { diff --git a/core/upgrader/impl/src/upgrade.rs b/core/upgrader/impl/src/upgrade.rs index 3e4ec301b..0919ec901 100644 --- a/core/upgrader/impl/src/upgrade.rs +++ b/core/upgrader/impl/src/upgrade.rs @@ -1,11 +1,10 @@ use crate::{ + get_target_canister, model::{LogEntryType, UpgradeResultLog}, services::LOGGER_SERVICE, - LocalRef, StableValue, StorablePrincipal, }; -use anyhow::{anyhow, Context}; +use anyhow::anyhow; use async_trait::async_trait; -use candid::Principal; use ic_cdk::api::management_canister::main::{ self as mgmt, CanisterIdRecord, CanisterInfoRequest, CanisterInstallMode, }; @@ -41,23 +40,12 @@ pub trait Upgrade: 'static + Sync + Send { } #[derive(Clone)] -pub struct Upgrader { - target: LocalRef>, -} - -impl Upgrader { - pub fn new(target: LocalRef>) -> Self { - Self { target } - } -} +pub struct Upgrader {} #[async_trait] impl Upgrade for Upgrader { async fn upgrade(&self, ps: UpgradeParams) -> Result<(), UpgradeError> { - let target_canister = self - .target - .with(|id| id.borrow().get(&()).context("canister id not set"))? - .0; + let target_canister = get_target_canister(); install_chunked_code( target_canister, @@ -71,17 +59,15 @@ impl Upgrade for Upgrader { } } -pub struct WithStop(pub T, pub LocalRef>); +pub struct WithStop(pub T); #[async_trait] impl Upgrade for WithStop { /// Perform an upgrade but ensure that the target canister is stopped first async fn upgrade(&self, ps: UpgradeParams) -> Result<(), UpgradeError> { - let id = self - .1 - .with(|id| id.borrow().get(&()).context("canister id not set"))?; + let id = get_target_canister(); - mgmt::stop_canister(CanisterIdRecord { canister_id: id.0 }) + mgmt::stop_canister(CanisterIdRecord { canister_id: id }) .await .map_err(|(_, err)| anyhow!("failed to stop canister: {err}"))?; @@ -89,7 +75,7 @@ impl Upgrade for WithStop { } } -pub struct WithStart(pub T, pub LocalRef>); +pub struct WithStart(pub T); #[async_trait] impl Upgrade for WithStart { @@ -98,11 +84,9 @@ impl Upgrade for WithStart { async fn upgrade(&self, ps: UpgradeParams) -> Result<(), UpgradeError> { let out = self.0.upgrade(ps).await; - let id = self - .1 - .with(|id| id.borrow().get(&()).context("canister id not set"))?; + let id = get_target_canister(); - mgmt::start_canister(CanisterIdRecord { canister_id: id.0 }) + mgmt::start_canister(CanisterIdRecord { canister_id: id }) .await .map_err(|(_, err)| anyhow!("failed to start canister: {err}"))?; @@ -110,7 +94,7 @@ impl Upgrade for WithStart { } } -pub struct WithBackground(pub Arc, pub LocalRef>); +pub struct WithBackground(pub Arc); #[async_trait] impl Upgrade for WithBackground { @@ -118,37 +102,34 @@ impl Upgrade for WithBackground { /// so that it is performed in a non-blocking manner async fn upgrade(&self, ps: UpgradeParams) -> Result<(), UpgradeError> { let u = self.0.clone(); - let target_canister_id: Option = - self.1.with(|p| p.borrow().get(&()).map(|sp| sp.0)); + let target_canister_id = get_target_canister(); ic_cdk::spawn(async move { let res = u.upgrade(ps).await; // Notify the target canister about a failed upgrade unless the call is unauthorized // (we don't want to spam the target canister with such errors). - if let Some(target_canister_id) = target_canister_id { - if let Err(ref err) = res { - let err = match err { - UpgradeError::UnexpectedError(err) => Some(err.to_string()), - UpgradeError::NotController => Some( - "The upgrader canister is not a controller of the target canister" - .to_string(), - ), - UpgradeError::Unauthorized => None, - }; - if let Some(err) = err { - let notify_failed_station_upgrade_input = - NotifyFailedStationUpgradeInput { reason: err }; - let notify_res = call::<_, (ApiResult<()>,)>( - target_canister_id, - "notify_failed_station_upgrade", - (notify_failed_station_upgrade_input,), - ) - .await - .map(|r| r.0); - // Log an error if the notification can't be made. - if let Err(e) = notify_res { - print(format!("notify_failed_station_upgrade failed: {:?}", e)); - } + if let Err(ref err) = res { + let err = match err { + UpgradeError::UnexpectedError(err) => Some(err.to_string()), + UpgradeError::NotController => Some( + "The upgrader canister is not a controller of the target canister" + .to_string(), + ), + UpgradeError::Unauthorized => None, + }; + if let Some(err) = err { + let notify_failed_station_upgrade_input = + NotifyFailedStationUpgradeInput { reason: err }; + let notify_res = call::<_, (ApiResult<()>,)>( + target_canister_id, + "notify_failed_station_upgrade", + (notify_failed_station_upgrade_input,), + ) + .await + .map(|r| r.0); + // Log an error if the notification can't be made. + if let Err(e) = notify_res { + print(format!("notify_failed_station_upgrade failed: {:?}", e)); } } } @@ -158,16 +139,14 @@ impl Upgrade for WithBackground { } } -pub struct WithAuthorization(pub T, pub LocalRef>); +pub struct WithAuthorization(pub T); #[async_trait] impl Upgrade for WithAuthorization { async fn upgrade(&self, ps: UpgradeParams) -> Result<(), UpgradeError> { - let id = self - .1 - .with(|id| id.borrow().get(&()).context("canister id not set"))?; + let id = get_target_canister(); - if !ic_cdk::caller().eq(&id.0) { + if !ic_cdk::caller().eq(&id) { return Err(UpgradeError::Unauthorized); } @@ -175,17 +154,15 @@ impl Upgrade for WithAuthorization { } } -pub struct CheckController(pub T, pub LocalRef>); +pub struct CheckController(pub T); #[async_trait] impl Upgrade for CheckController { async fn upgrade(&self, ps: UpgradeParams) -> Result<(), UpgradeError> { - let id = self - .1 - .with(|id| id.borrow().get(&()).context("canister id not set"))?; + let id = get_target_canister(); let (resp,) = mgmt::canister_info(CanisterInfoRequest { - canister_id: id.0, + canister_id: id, num_requested_changes: None, }) .await