Skip to content

Commit

Permalink
feat(upgrader): new stable memory layout
Browse files Browse the repository at this point in the history
  • Loading branch information
mraszyk committed Jan 10, 2025
1 parent 2fe96cb commit d74ab4c
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 174 deletions.
148 changes: 122 additions & 26 deletions core/upgrader/impl/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
use crate::model::{DisasterRecovery, LogEntry, State};
use crate::services::set_logs;
use crate::upgrade::{
CheckController, Upgrade, Upgrader, WithAuthorization, WithBackground, WithLogs, WithStart,
WithStop,
};
use candid::Principal;
use ic_cdk::{api::management_canister::main::CanisterInstallMode, init, update};
use ic_cdk::api::stable::stable_write;
use ic_cdk::{
api::management_canister::main::CanisterInstallMode, init, post_upgrade, trap, update,
};
use ic_stable_structures::{
memory_manager::{MemoryId, MemoryManager, VirtualMemory},
DefaultMemoryImpl, StableBTreeMap,
storable::Bound,
DefaultMemoryImpl, StableBTreeMap, Storable,
};
use lazy_static::lazy_static;
use orbit_essentials::storable;
use std::{cell::RefCell, sync::Arc, thread::LocalKey};
use orbit_essentials::types::Timestamp;
use std::{borrow::Cow, cell::RefCell, collections::BTreeMap, sync::Arc};
use upgrade::{UpgradeError, UpgradeParams};
use upgrader_api::{InitArg, TriggerUpgradeError};

Expand All @@ -29,45 +35,135 @@ pub mod utils;
type Memory = VirtualMemory<DefaultMemoryImpl>;
type StableMap<K, V> = StableBTreeMap<K, V, Memory>;
type StableValue<T> = StableMap<(), T>;
type LocalRef<T> = &'static LocalKey<RefCell<T>>;

const MEMORY_ID_TARGET_CANISTER_ID: u8 = 0;
const MEMORY_ID_DISASTER_RECOVERY: u8 = 1;
const MEMORY_ID_LOGS: u8 = 4;
/// Represents one mebibyte.
pub const MIB: u32 = 1 << 20;

thread_local! {
static MEMORY_MANAGER: RefCell<MemoryManager<DefaultMemoryImpl>> =
RefCell::new(MemoryManager::init(DefaultMemoryImpl::default()));
}
/// Canisters use 64KiB pages for Wasm memory, more details in the PR that introduced this constant:
/// - https://github.com/WebAssembly/design/pull/442#issuecomment-153203031
pub const WASM_PAGE_SIZE: u32 = 65536;

#[storable]
pub struct StorablePrincipal(Principal);
/// The size of the stable memory bucket in WASM pages.
///
/// We use a bucket size of 1MiB to ensure that the default memory allocated to the canister is as small as possible,
/// this is due to the fact that this canister uses several MemoryIds to manage the stable memory similarly to how
/// a database arranges data per table.
///
/// Currently a bucket size of 1MiB limits the canister to 32GiB of stable memory, which is more than enough for the
/// current use case, however, if the canister needs more memory in the future, `ic-stable-structures` will need to be
/// updated to support storing more buckets in a backwards compatible way.
pub const STABLE_MEMORY_BUCKET_SIZE: u16 = (MIB / WASM_PAGE_SIZE) as u16;

/// Current version of stable memory layout.
pub const STABLE_MEMORY_VERSION: u32 = 1;

const MEMORY_ID_STATE: u8 = 0;
const MEMORY_ID_LOGS: u8 = 1;

thread_local! {
static TARGET_CANISTER_ID: RefCell<StableValue<StorablePrincipal>> = RefCell::new(
static MEMORY_MANAGER: RefCell<MemoryManager<DefaultMemoryImpl>> =
RefCell::new(MemoryManager::init_with_bucket_size(DefaultMemoryImpl::default(), STABLE_MEMORY_BUCKET_SIZE));
static STATE: RefCell<StableValue<State>> = RefCell::new(
StableValue::init(
MEMORY_MANAGER.with(|m| m.borrow().get(MemoryId::new(MEMORY_ID_TARGET_CANISTER_ID))),
MEMORY_MANAGER.with(|m| m.borrow().get(MemoryId::new(MEMORY_ID_STATE))),
)
);
}

/// Returns the `State` stored under the unit key in `STATE`, or `State::default()`
/// when no value has been stored yet.
fn get_state() -> State {
    STATE.with(|cell| {
        let stored = cell.borrow().get(&());
        stored.unwrap_or_default()
    })
}

/// Stores `state` under the unit key in `STATE`, replacing any previous value.
fn set_state(state: State) {
    STATE.with(|cell| {
        cell.borrow_mut().insert((), state);
    });
}

pub fn get_target_canister() -> Principal {
get_state().target_canister
}

pub fn set_target_canister(target_canister: Principal) {
let mut state = get_state();
state.target_canister = target_canister;
set_state(state);
}

pub fn get_disaster_recovery() -> DisasterRecovery {
get_state().disaster_recovery
}

pub fn set_disaster_recovery(value: DisasterRecovery) {
let mut state = get_state();
state.disaster_recovery = value;
set_state(state);
}

#[init]
fn init_fn(InitArg { target_canister }: InitArg) {
TARGET_CANISTER_ID.with(|id| {
let mut id = id.borrow_mut();
id.insert((), StorablePrincipal(target_canister));
});
set_target_canister(target_canister);
}

/// Post-upgrade hook that migrates stable memory from the old layout to the current one.
///
/// The layout is detected by reading the raw bytes stored under
/// `OLD_MEMORY_ID_TARGET_CANISTER_ID` and trying to decode them as a CBOR `Principal`.
/// If decoding succeeds, the old disaster recovery value and the old logs are read,
/// the stable structures magic header is cleared, and the data is written back
/// through `set_state` and `set_logs`. If decoding fails, nothing is modified.
///
/// Traps if no value is stored under `OLD_MEMORY_ID_TARGET_CANISTER_ID`.
#[post_upgrade]
fn post_upgrade() {
// Read-only wrapper that hands back a stored value as its undecoded bytes;
// serializing it traps, so it must never be inserted into a stable structure.
pub struct RawBytes(pub Vec<u8>);
impl Storable for RawBytes {
fn to_bytes(&self) -> Cow<[u8]> {
trap("RawBytes should never be serialized")
}

fn from_bytes(bytes: Cow<[u8]>) -> Self {
Self(bytes.to_vec())
}

const BOUND: Bound = Bound::Unbounded;
}

// Memory ids used by the old stable memory layout.
const OLD_MEMORY_ID_TARGET_CANISTER_ID: u8 = 0;
const OLD_MEMORY_ID_DISASTER_RECOVERY: u8 = 1;
const OLD_MEMORY_ID_LOGS: u8 = 4;

// NOTE(review): this manager is created with `MemoryManager::init`, not with the
// `init_with_bucket_size` call used by `MEMORY_MANAGER` — presumably to match the
// old layout's bucket size; confirm against ic-stable-structures.
let old_memory_manager = MemoryManager::init(DefaultMemoryImpl::default());

// determine stable memory layout by trying to parse the target canister id from OLD_MEMORY_ID_TARGET_CANISTER_ID
let old_target_canister_bytes: StableValue<RawBytes> =
StableValue::init(old_memory_manager.get(MemoryId::new(OLD_MEMORY_ID_TARGET_CANISTER_ID)));
let target_canister_bytes = old_target_canister_bytes
.get(&())
.unwrap_or_else(|| trap("Could not determine stable memory layout."));
// Only migrate when the stored bytes decode as a `Principal`; otherwise leave stable memory untouched.
if let Ok(target_canister) = serde_cbor::from_slice::<Principal>(&target_canister_bytes.0) {
let old_disaster_recovery: StableValue<DisasterRecovery> = StableValue::init(
old_memory_manager.get(MemoryId::new(OLD_MEMORY_ID_DISASTER_RECOVERY)),
);
let disaster_recovery: DisasterRecovery =
old_disaster_recovery.get(&()).unwrap_or_default();

// Copy all old log entries into a heap map before the header is cleared below.
let old_logs: StableBTreeMap<Timestamp, LogEntry, Memory> =
StableBTreeMap::init(old_memory_manager.get(MemoryId::new(OLD_MEMORY_ID_LOGS)));
let logs: BTreeMap<Timestamp, LogEntry> = old_logs.iter().collect();

// clear the magic header of stable structures to force their reinitialization
// https://github.com/dfinity/stable-structures/blob/69ed47f9b5001af67d650c714cd56ec3ee0ef2bb/src/memory_manager.rs#L254-L256
stable_write(0, &[0; 3]);

// Write the migrated data through the current-layout accessors; this must happen
// after the header has been cleared above.
let state = State {
target_canister,
disaster_recovery,
stable_memory_version: STABLE_MEMORY_VERSION,
};
set_state(state);
set_logs(logs);
}
}

lazy_static! {
static ref UPGRADER: Box<dyn Upgrade> = {
let u = Upgrader::new(&TARGET_CANISTER_ID);
let u = WithStop(u, &TARGET_CANISTER_ID);
let u = WithStart(u, &TARGET_CANISTER_ID);
let u = Upgrader {};
let u = WithStop(u);
let u = WithStart(u);
let u = WithLogs(u, "upgrade".to_string());
let u = WithBackground(Arc::new(u), &TARGET_CANISTER_ID);
let u = CheckController(u, &TARGET_CANISTER_ID);
let u = WithAuthorization(u, &TARGET_CANISTER_ID);
let u = WithBackground(Arc::new(u));
let u = CheckController(u);
let u = WithAuthorization(u);
let u = WithLogs(u, "trigger_upgrade".to_string());
Box::new(u)
};
Expand Down
21 changes: 21 additions & 0 deletions core/upgrader/impl/src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,24 @@ mod logging;

pub use disaster_recovery::*;
pub use logging::*;

use crate::STABLE_MEMORY_VERSION;
use candid::Principal;
use orbit_essentials::storable;

/// Upgrader state that is stored as a single value in stable memory.
#[storable]
pub struct State {
/// Principal of the target canister; `Default` uses the anonymous principal.
pub target_canister: Principal,
/// Disaster recovery data kept alongside the target canister.
pub disaster_recovery: DisasterRecovery,
/// Stable memory layout version recorded with this state (`STABLE_MEMORY_VERSION` when created here).
pub stable_memory_version: u32,
}

impl Default for State {
fn default() -> Self {
Self {
target_canister: Principal::anonymous(),
disaster_recovery: Default::default(),
stable_memory_version: STABLE_MEMORY_VERSION,
}
}
}
89 changes: 10 additions & 79 deletions core/upgrader/impl/src/services/disaster_recovery.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,32 @@
use std::{cell::RefCell, collections::HashMap, sync::Arc};
use std::{collections::HashMap, sync::Arc};

use crate::{
errors::UpgraderApiError,
get_disaster_recovery, get_target_canister,
model::{
Asset, DisasterRecoveryInProgressLog, DisasterRecoveryResultLog, DisasterRecoveryStartLog,
LogEntryType, MultiAssetAccount, RequestDisasterRecoveryLog, SetAccountsAndAssetsLog,
SetAccountsLog, SetCommitteeLog,
},
services::LOGGER_SERVICE,
set_disaster_recovery,
upgrader_ic_cdk::{api::time, spawn},
};
use candid::Principal;
use ic_stable_structures::memory_manager::MemoryId;
use lazy_static::lazy_static;
use orbit_essentials::{api::ServiceResult, utils::sha256_hash};

use crate::{
model::{
Account, AdminUser, DisasterRecovery, DisasterRecoveryCommittee, InstallMode,
RecoveryEvaluationResult, RecoveryFailure, RecoveryResult, RecoveryStatus,
StationRecoveryRequest,
},
StableValue, MEMORY_ID_DISASTER_RECOVERY, MEMORY_MANAGER, TARGET_CANISTER_ID,
use crate::model::{
Account, AdminUser, DisasterRecovery, DisasterRecoveryCommittee, InstallMode,
RecoveryEvaluationResult, RecoveryFailure, RecoveryResult, RecoveryStatus,
StationRecoveryRequest,
};

use super::{InstallCanister, LoggerService, INSTALL_CANISTER};

pub const DISASTER_RECOVERY_REQUEST_EXPIRATION_NS: u64 = 60 * 60 * 24 * 7 * 1_000_000_000; // 1 week
pub const DISASTER_RECOVERY_IN_PROGESS_EXPIRATION_NS: u64 = 60 * 60 * 1_000_000_000; // 1 hour

thread_local! {

static STORAGE: RefCell<StableValue<DisasterRecovery>> = RefCell::new(
StableValue::init(
MEMORY_MANAGER.with(|m| m.borrow().get(MemoryId::new(MEMORY_ID_DISASTER_RECOVERY))),
)
);

}

lazy_static! {
pub static ref DISASTER_RECOVERY_SERVICE: Arc<DisasterRecoveryService> =
Arc::new(DisasterRecoveryService {
Expand Down Expand Up @@ -83,11 +71,11 @@ pub struct DisasterRecoveryStorage {}

impl DisasterRecoveryStorage {
pub fn get(&self) -> DisasterRecovery {
STORAGE.with(|storage| storage.borrow().get(&()).unwrap_or_default())
get_disaster_recovery()
}

fn set(&self, value: DisasterRecovery) {
STORAGE.with(|storage| storage.borrow_mut().insert((), value));
set_disaster_recovery(value);
}
}

Expand Down Expand Up @@ -282,15 +270,7 @@ impl DisasterRecoveryService {
return;
}

let Some(station_canister_id) =
TARGET_CANISTER_ID.with(|id| id.borrow().get(&()).map(|id| id.0))
else {
value.last_recovery_result = Some(RecoveryResult::Failure(RecoveryFailure {
reason: "Station canister ID not set".to_string(),
}));
storage.set(value);
return;
};
let station_canister_id = get_target_canister();

value.recovery_status = RecoveryStatus::InProgress { since: time() };
storage.set(value);
Expand Down Expand Up @@ -415,7 +395,6 @@ mod tests {
services::{
DisasterRecoveryService, DisasterRecoveryStorage, InstallCanister, LoggerService,
},
StorablePrincipal, TARGET_CANISTER_ID,
};

#[derive(Default)]
Expand Down Expand Up @@ -575,11 +554,6 @@ mod tests {

#[tokio::test]
async fn test_do_recovery() {
TARGET_CANISTER_ID.with(|id| {
id.borrow_mut()
.insert((), StorablePrincipal(Principal::anonymous()));
});

let storage: DisasterRecoveryStorage = Default::default();
let logger = Arc::new(LoggerService::default());
let recovery_request = StationRecoveryRequest {
Expand Down Expand Up @@ -655,51 +629,8 @@ mod tests {
);
}

#[tokio::test]
async fn test_failing_do_recovery_with_no_target_canister_id() {
// setup: TARGET_CANISTER_ID is not set, so recovery should fail

let storage: DisasterRecoveryStorage = Default::default();
let logger = Arc::new(LoggerService::default());
let recovery_request = StationRecoveryRequest {
user_id: [1; 16],
wasm_module: vec![1, 2, 3],
wasm_module_extra_chunks: None,
wasm_sha256: vec![4, 5, 6],
install_mode: InstallMode::Reinstall,
arg: vec![7, 8, 9],
arg_sha256: vec![10, 11, 12],
submitted_at: 0,
};

let installer = Arc::new(TestInstaller::default());

DisasterRecoveryService::do_recovery(
storage.clone(),
installer.clone(),
logger.clone(),
recovery_request.clone(),
)
.await;

assert!(matches!(
storage.get().last_recovery_result,
Some(RecoveryResult::Failure(_))
));

assert!(matches!(
storage.get().recovery_status,
RecoveryStatus::Idle
));
}

#[tokio::test]
async fn test_failing_do_recovery_with_panicking_install() {
TARGET_CANISTER_ID.with(|id| {
id.borrow_mut()
.insert((), StorablePrincipal(Principal::anonymous()));
});

let storage: DisasterRecoveryStorage = Default::default();
let logger = Arc::new(LoggerService::default());
let recovery_request = StationRecoveryRequest {
Expand Down
17 changes: 13 additions & 4 deletions core/upgrader/impl/src/services/logger.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{cell::RefCell, sync::Arc};
use std::{cell::RefCell, collections::BTreeMap, sync::Arc};

use ic_stable_structures::{memory_manager::MemoryId, BTreeMap};
use ic_stable_structures::{memory_manager::MemoryId, StableBTreeMap};
use lazy_static::lazy_static;
use orbit_essentials::types::Timestamp;

Expand All @@ -14,13 +14,22 @@ pub const DEFAULT_GET_LOGS_LIMIT: u64 = 10;
pub const MAX_LOG_ENTRIES: u64 = 25000;

thread_local! {
static STORAGE: RefCell<BTreeMap<Timestamp, LogEntry, Memory>> = RefCell::new(
BTreeMap::init(
static STORAGE: RefCell<StableBTreeMap<Timestamp, LogEntry, Memory>> = RefCell::new(
StableBTreeMap::init(
MEMORY_MANAGER.with(|m| m.borrow().get(MemoryId::new(MEMORY_ID_LOGS))),
)
);
}

/// Inserts every entry of `logs` into the stable log storage, keyed by timestamp.
///
/// Only use this function for stable memory migration!
pub fn set_logs(logs: BTreeMap<Timestamp, LogEntry>) {
    STORAGE.with(|cell| {
        let mut stored = cell.borrow_mut();
        logs.into_iter().for_each(|(timestamp, entry)| {
            stored.insert(timestamp, entry);
        });
    });
}

lazy_static! {
pub static ref LOGGER_SERVICE: Arc<LoggerService> = Arc::new(LoggerService::default());
}
Expand Down
Loading

0 comments on commit d74ab4c

Please sign in to comment.