Skip to content

Commit

Permalink
feat: add "persist" feature
Browse files Browse the repository at this point in the history
This patch adds a new feature "persist". Using this feature, the upper layer can save and restore the state of the struct (for example, `Vfs` and `PseudoFs`) which implement the `crate::api::persist::Snapshotter` trait.

This feature introduces a new trait `create::api::persist::Snapshotter`, which has two methods:
- `fn save_to_bytes(&self) -> Result<Vec<u8>>` which saves the state of the struct to a byte array
- ` fn load_from_bytes(constructor_args: Self::ConstructorArgs, buf: &mut Vec<u8>) -> Result<Self>` which restores the state of the struct from a byte array

The `Snapshotter` trait uses [the `Snapshot` crate](https://github.com/firecracker-microvm/firecracker/tree/main/src/snapshot) to serialize and deserialize the struct data. Therefore, the struct which implement the `Snapshotter` trait must implement the `snapshot::Persist` trait and implement the `create::api::persist::VersionManager` trait to define it's versions.

Signed-off-by: Nan Li <[email protected]>
  • Loading branch information
loheagn committed Sep 18, 2023
1 parent 8c89657 commit 0d1d50f
Show file tree
Hide file tree
Showing 5 changed files with 632 additions and 3 deletions.
23 changes: 20 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ build = "build.rs"
arc-swap = "1.5"
async-trait = { version = "0.1.42", optional = true }
bitflags = "1.1"
snapshot = { git = "https://github.com/firecracker-microvm/firecracker", tag = "v1.4.1", optional = true }
io-uring = { version = "0.5.8", optional = true }
libc = "0.2.68"
log = "0.4.6"
mio = { version = "0.8", features = ["os-poll", "os-ext"]}
mio = { version = "0.8", features = ["os-poll", "os-ext"] }
nix = "0.24"
lazy_static = "1.4"
tokio = { version = "1", optional = true }
Expand All @@ -32,6 +33,8 @@ vmm-sys-util = { version = "0.11", optional = true }
vm-memory = { version = "0.10", features = ["backend-mmap"] }
virtio-queue = { version = "0.7", optional = true }
vhost = { version = "0.6", features = ["vhost-user-slave"], optional = true }
versionize_derive = { version = "0.1.6", optional = true }
versionize = { version = "0.1.10", optional = true }

[target.'cfg(target_os = "macos")'.dependencies]
core-foundation-sys = { version = ">=0.8", optional = true }
Expand All @@ -47,11 +50,25 @@ vm-memory = { version = "0.10", features = ["backend-mmap", "backend-bitmap"] }

[features]
default = ["fusedev"]
async-io = ["async-trait", "tokio-uring", "tokio/fs", "tokio/net", "tokio/sync", "tokio/rt", "tokio/macros", "io-uring"]
async-io = [
"async-trait",
"tokio-uring",
"tokio/fs",
"tokio/net",
"tokio/sync",
"tokio/rt",
"tokio/macros",
"io-uring",
]
fusedev = ["vmm-sys-util", "caps", "core-foundation-sys"]
virtiofs = ["virtio-queue", "caps", "vmm-sys-util"]
vhost-user-fs = ["virtiofs", "vhost", "caps"]
persist = ["snapshot", "versionize", "versionize_derive"]

[package.metadata.docs.rs]
all-features = true
targets = ["x86_64-unknown-linux-gnu", "aarch64-unknown-linux-gnu", "aarch64-apple-darwin"]
targets = [
"x86_64-unknown-linux-gnu",
"aarch64-unknown-linux-gnu",
"aarch64-apple-darwin",
]
4 changes: 4 additions & 0 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ pub use vfs::{

pub mod filesystem;
pub mod server;

/// The module is used to serialize and deserialize data (for example, the `Vfs` and `PseudoFs`).
#[cfg(feature = "persist")]
pub mod persist;
160 changes: 160 additions & 0 deletions src/api/persist.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
use std::{
any::TypeId,
collections::HashMap,
fmt::Debug,
io::{Error as IoError, ErrorKind, Result},
};

use snapshot::{Persist, Snapshot};
use versionize::{VersionMap, Versionize};

/// A list of versions.
type Versions = Vec<HashMap<TypeId, u16>>;

/// Version Manager trait
pub trait VersionManger {
/// Returns a list of versions.
fn get_versions() -> Versions;

/// Returns a `VersionMap` with the versions defined by `get_versions`.
fn new_version_map() -> VersionMap {
let mut version_map = VersionMap::new();
for (idx, map) in Self::get_versions().into_iter().enumerate() {
if idx > 0 {
version_map.new_version();
}
for (type_id, version) in map {
version_map.set_type_version(type_id, version);
}
}
version_map
}

/// Returns a new `Snapshot` with the versions defined by `get_versions`.
fn new_snapshot() -> Snapshot {
let vm = Self::new_version_map();
let target_version = vm.latest_version();
Snapshot::new(vm, target_version)
}
}

/// Snapshotter trait
pub trait Snapshotter<'a>: Persist<'a>
where
Self::State: Versionize + VersionManger,
Self::Error: Debug,
{
/// Serializes `self` to a byte array.
fn save_to_bytes(&self) -> Result<Vec<u8>> {
let state = self.save();
let mut buf = Vec::new();
let mut snapshot = Self::State::new_snapshot();
snapshot.save(&mut buf, &state).map_err(|e| {
IoError::new(
ErrorKind::Other,
format!("Failed to save snapshot: {:?}", e),
)
})?;

Ok(buf)
}

/// Restores `self` from a byte array.
fn load_from_bytes(constructor_args: Self::ConstructorArgs, buf: &mut Vec<u8>) -> Result<Self> {
let state: Self::State = Snapshot::load(
&mut buf.as_slice(),
buf.len(),
Self::State::new_version_map(),
)
.map_err(|e| {
IoError::new(
ErrorKind::Other,
format!("Failed to load snapshot: {:?}", e),
)
})?;
let restored_self = Self::restore(constructor_args, &state).map_err(|e| {
IoError::new(
ErrorKind::Other,
format!("Failed to restore snapshot: {:?}", e),
)
})?;

Ok(restored_self)
}
}

impl<'a, T: Persist<'a>> Snapshotter<'a> for T
where
T::State: Versionize + VersionManger,
T::Error: Debug,
{
}

mod test {
use std::collections::HashMap;

use snapshot::Persist;
use versionize::{VersionMap, Versionize, VersionizeResult};
use versionize_derive::Versionize;

use super::VersionManger;

#[derive(Debug, PartialEq)]
struct Test {
a: u32,
b: u32,
}

#[derive(Debug, Versionize)]
struct TestState {
a: u32,
b: u32,
}

impl VersionManger for TestState {
fn get_versions() -> super::Versions {
vec![HashMap::from([(std::any::TypeId::of::<u32>(), 1)])]
}
}

impl Persist<'_> for Test {
type State = TestState;

type ConstructorArgs = ();

type Error = std::io::Error;

fn save(&self) -> Self::State {
TestState {
a: self.a,
b: self.b,
}
}

fn restore(
_constructor_args: Self::ConstructorArgs,
state: &Self::State,
) -> std::result::Result<Self, Self::Error> {
Ok(Test {
a: state.a,
b: state.b,
})
}
}

#[test]
fn save_load_test() {
use crate::api::persist::Snapshotter;

let t = Test { a: 1u32, b: 4u32 };

// save
let mut buf = t.save_to_bytes().unwrap();

// restore
let restored_t = Test::load_from_bytes((), &mut buf).unwrap();

// assert
assert_eq!(t, restored_t);
}
}
156 changes: 156 additions & 0 deletions src/api/pseudo_fs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,162 @@ impl FileSystem for PseudoFs {
}
}

/// Save and restore PseudoFs state.
#[cfg(feature = "persist")]
pub mod persist {
use std::any::TypeId;
use std::collections::HashMap;
use std::io::{Error as IoError, ErrorKind};
use std::sync::atomic::Ordering;
use std::sync::Arc;

use snapshot::Persist;
use versionize::{VersionMap, Versionize, VersionizeResult};
use versionize_derive::Versionize;

use super::{PseudoFs, PseudoInode};
use crate::api::filesystem::ROOT_ID;
use crate::api::persist::VersionManger;

#[derive(Versionize, PartialEq, Debug, Default, Clone)]
struct PseudoInodeState {
ino: u64,
parent: u64,
name: String,
}

#[derive(Versionize, PartialEq, Debug, Default)]
pub struct PseudoFsState {
next_inode: u64,
inodes: Vec<PseudoInodeState>,
}

impl VersionManger for PseudoFsState {
fn get_versions() -> Vec<HashMap<TypeId, u16>> {
let mut versions = vec![];

// version 1
versions.push(HashMap::from([(TypeId::of::<PseudoFsState>(), 1)]));

// more versions for the future

versions
}
}

impl<'a> Persist<'a> for &'a PseudoFs {
type State = PseudoFsState;

type ConstructorArgs = &'a PseudoFs;

type Error = IoError;

fn save(&self) -> Self::State {
let mut inodes = Vec::new();
let next_inode = self.next_inode.load(Ordering::Relaxed);

let _guard = self.lock.lock().unwrap();
for inode in self.inodes.load().values() {
if inode.ino == ROOT_ID {
// no need to save the root inode
continue;
}

inodes.push(PseudoInodeState {
ino: inode.ino,
parent: inode.parent,
name: inode.name.clone(),
});
}

PseudoFsState { next_inode, inodes }
}

fn restore(fs: Self::ConstructorArgs, state: &Self::State) -> Result<Self, Self::Error> {
// first, reconstruct all the inodes
let mut inode_map = HashMap::new();
let mut state_inodes = state.inodes.clone();
for inode in state_inodes.iter() {
let inode = Arc::new(PseudoInode::new(
inode.ino,
inode.parent,
inode.name.clone(),
));
inode_map.insert(inode.ino, inode);
}

// insert root inode to make sure the others inodes can find their parents
inode_map.insert(fs.root_inode.ino, fs.root_inode.clone());

// then, connect the inodes
state_inodes.sort_by(|a, b| a.ino.cmp(&b.ino));
for inode in state_inodes.iter() {
let inode = inode_map
.get(&inode.ino)
.ok_or_else(|| {
IoError::new(
ErrorKind::InvalidData,
format!("invalid inode {}", inode.ino),
)
})?
.clone();
let parent = inode_map.get_mut(&inode.parent).ok_or_else(|| {
IoError::new(
ErrorKind::InvalidData,
format!(
"invalid parent inode {} for inode {}",
inode.parent, inode.ino
),
)
})?;
parent.insert_child(inode);
}
fs.inodes.store(Arc::new(inode_map));

// last, restore next_inode
fs.next_inode.store(state.next_inode, Ordering::Relaxed);

Ok(fs)
}
}

mod test {

#[test]
fn save_restore_test() {
use crate::api::persist::Snapshotter;
use crate::api::pseudo_fs::PseudoFs;

let fs = &PseudoFs::new();
let paths = vec!["/a", "/a/b", "/a/b/c", "/b", "/b/a/c", "/d"];

for path in paths.iter() {
fs.mount(path).unwrap();
}

// save fs
let mut buf = fs.save_to_bytes().unwrap();

// restore fs
let restored_fs = &PseudoFs::new();
let restored_fs = <&PseudoFs>::load_from_bytes(restored_fs, &mut buf).unwrap();

// check fs and restored_fs
let next_inode = fs.next_inode.load(std::sync::atomic::Ordering::Relaxed);
let restored_next_inode = restored_fs
.next_inode
.load(std::sync::atomic::Ordering::Relaxed);
assert_eq!(next_inode, restored_next_inode);

for path in paths.iter() {
let inode = fs.path_walk(path).unwrap();
let restored_inode = restored_fs.path_walk(path).unwrap();
assert_eq!(inode, restored_inode);
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit 0d1d50f

Please sign in to comment.