Skip to content

Commit

Permalink
feat: upstream ops from WASI-virt
Browse files Browse the repository at this point in the history
WASI-virt contains functions that are helpful for manipulating modules
and dealing with exports/imports, which would be helpful to an even
wider group if upstreamed here to walrus.

This commit copies and upstreams some operations that were introduced
in WASI-virt for wider use via walrus.

See also: bytecodealliance/WASI-Virt#20

Signed-off-by: Victor Adossi <[email protected]>
  • Loading branch information
vados-cosmonic committed Oct 13, 2023
1 parent 440dc03 commit 4331cb2
Show file tree
Hide file tree
Showing 4 changed files with 424 additions and 8 deletions.
22 changes: 22 additions & 0 deletions src/module/exports.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Exported items in a wasm module.

use anyhow::Context;

use crate::emit::{Emit, EmitContext};
use crate::parse::IndicesToIds;
use crate::tombstone_arena::{Id, Tombstone, TombstoneArena};
Expand Down Expand Up @@ -100,6 +102,16 @@ impl ModuleExports {
})
}

/// Retrieve an exported function by name
pub fn get_func_by_name(&self, name: impl AsRef<str>) -> Result<FunctionId> {
self.iter()
.find_map(|expt| match expt.item {
ExportItem::Function(fid) if expt.name == name.as_ref() => Some(fid),
_ => None,
})
.with_context(|| format!("unable to find function export '{}'", name.as_ref()))
}

/// Get a reference to a table export given its table id.
pub fn get_exported_table(&self, t: TableId) -> Option<&Export> {
self.iter().find(|e| match e.item {
Expand Down Expand Up @@ -354,6 +366,16 @@ mod tests {
assert!(module.exports.get_exported_func(fn_id).is_none());
}

#[test]
fn get_func_by_name() {
let mut module = Module::default();
let fn_id: FunctionId = always_the_same_id();
let export_id: ExportId = module.exports.add("dummy", fn_id);
assert!(module.exports.get_func_by_name("dummy").is_ok());
module.exports.delete(export_id);
assert!(module.exports.get_func_by_name("dummy").is_err());
}

#[test]
fn iter_mut_can_update_export_item() {
let mut module = Module::default();
Expand Down
326 changes: 319 additions & 7 deletions src/module/functions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
//! Functions within a wasm module.

use std::cmp;
use std::collections::BTreeMap;

use anyhow::{bail, Context};
use wasm_encoder::Encode;
use wasmparser::{FuncValidator, FunctionBody, Range, ValidatorResources};

#[cfg(feature = "parallel")]
use rayon::prelude::*;

mod local_function;

use crate::emit::{Emit, EmitContext};
Expand All @@ -11,19 +21,19 @@ use crate::parse::IndicesToIds;
use crate::tombstone_arena::{Id, Tombstone, TombstoneArena};
use crate::ty::TypeId;
use crate::ty::ValType;
use std::cmp;
use std::collections::BTreeMap;
use wasm_encoder::Encode;
use wasmparser::{FuncValidator, FunctionBody, Range, ValidatorResources};

#[cfg(feature = "parallel")]
use rayon::prelude::*;
use crate::{ExportItem, Memory, MemoryId};

pub use self::local_function::LocalFunction;

/// A function identifier.
pub type FunctionId = Id<Function>;

/// Parameter(s) to a function
pub type FuncParams = Vec<ValType>;

/// Result(s) of a given function
pub type FuncResults = Vec<ValType>;

/// A wasm function.
///
/// Either defined locally or externally and then imported; see `FunctionKind`.
Expand Down Expand Up @@ -418,6 +428,119 @@ impl Module {

Ok(())
}

/// Retrieve the ID for the first exported memory.
///
/// This method does not work in contexts with [multi-memory enabled](https://github.com/WebAssembly/multi-memory),
/// and will error if more than one memory is present.
pub fn get_memory_id(&self) -> Result<MemoryId> {
if self.memories.len() > 1 {
bail!("multiple memories unsupported")
}

self.memories
.iter()
.next()
.map(Memory::id)
.context("module does not export a memory")
}

/// Replace a single exported function with the result of the provided builder function.
///
/// For example, if you wanted to replace an exported function with a no-op,
///
/// ```ignore
/// // Since `FunctionBuilder` requires a mutable pointer to the module's types,
/// // we must build it *outside* the closure and `move` it in
/// let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
///
/// module.replace_exported_func(fid, move || {
/// builder.func_body().unreachable();
/// builder.local_func(vec![])
/// });
/// ```
///
/// This function returns the function ID of the *new* function,
/// after it has been inserted into the module as an export.
pub fn replace_exported_func<F>(&mut self, fid: FunctionId, fn_builder: F) -> Result<FunctionId>
where
F: FnOnce((&FuncParams, &FuncResults)) -> Result<LocalFunction>,
{
match (self.exports.get_exported_func(fid), self.funcs.get(fid)) {
(
Some(exported_fn),
Function {
kind: FunctionKind::Local(lf),
..
},
) => {
// Retrieve the params & result types for the exported (local) function
let ty = self.types.get(lf.ty());
let (params, results) = (ty.params().to_vec(), ty.results().to_vec());

// Add the function produced by `fn_builder` as a local function,
let new_fid = self.funcs.add_local(
fn_builder((&params, &results)).context("export fn builder failed")?,
);

// Mutate the existing export to use the new local function
let export = self.exports.get_mut(exported_fn.id());
export.item = ExportItem::Function(new_fid);

Ok(new_fid)
}
// The export didn't exist, or the function isn't the kind we expect
_ => bail!("cannot replace function [{fid:?}], it is not an exported function"),
}
}

/// Replace a single imported function with the result of the provided builder function.
///
/// ```ignore
/// // Since `FunctionBuilder` requires a mutable pointer to the module's types,
/// // we must build it *outside* the closure and `move` it in
/// let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
///
/// module.replace_imported_func(fid, move || {
/// builder.func_body().unreachable();
/// builder.local_func(vec![])
/// });
/// ```
///
/// This function returns the function ID of the *new* function, and
/// removes the existing import that has been replaced (the function will become local).
pub fn replace_imported_func<F>(&mut self, fid: FunctionId, fn_builder: F) -> Result<FunctionId>
where
F: FnOnce((&FuncParams, &FuncResults)) -> Result<LocalFunction>,
{
// If the function is in the imports, replace it
match (self.imports.get_imported_func(fid), self.funcs.get(fid)) {
(
Some(original_imported_fn),
Function {
kind: FunctionKind::Import(ImportedFunction { ty: tid, .. }),
..
},
) => {
// Retrieve the params & result types for the imported function
let ty = self.types.get(*tid);
let (params, results) = (ty.params().to_vec(), ty.results().to_vec());

// Mutate the existing function, changing it from a FunctionKind::ImportedFunction
// to the local function produced by running the provided `fn_builder`
let func = self.funcs.get_mut(fid);
func.kind = FunctionKind::Local(
fn_builder((&params, &results)).context("import fn builder failed")?,
);

self.imports.delete(original_imported_fn.id());

Ok(fid)
}
// The export didn't exist, or the function isn't the kind we expect
_ => bail!("cannot replace function [{fid:?}], it is not an imported function"),
}
}
}

fn used_local_functions<'a>(cx: &mut EmitContext<'a>) -> Vec<(FunctionId, &'a LocalFunction, u64)> {
Expand Down Expand Up @@ -535,3 +658,192 @@ impl Emit for ModuleFunctions {
cx.code_transform.instruction_map = instruction_map.into_iter().collect();
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{Export, Module, FunctionBuilder};

#[test]
fn get_memory_id() {
let mut module = Module::default();
let expected_id = module.memories.add_local(false, 0, None);
assert!(module.get_memory_id().is_ok_and(|id| id == expected_id));
}

/// Running `replace_exported_func` with a closure that builds
/// a function should replace the existing function with the new one
#[test]
fn replace_exported_func() {
let mut module = Module::default();

// Create original function
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
builder.func_body().i32_const(1234).drop();
let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs);
let original_export_id = module.exports.add("dummy", original_fn_id);

// Create builder to use inside closure
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);

// Replace the existing function with a new one with a reversed const value
let new_fn_id = module
.replace_exported_func(original_fn_id, move |_| {
builder.func_body().i32_const(4321).drop();
Ok(builder.local_func(vec![]))
})
.expect("function replacement worked");

assert!(
module.exports.get_exported_func(original_fn_id).is_none(),
"replaced function cannot be gotten by ID"
);

// Ensure the function was replaced
match module
.exports
.get_exported_func(new_fn_id)
.expect("failed to unwrap exported func")
{
exp @ Export {
item: ExportItem::Function(fid),
..
} => {
assert_eq!(*fid, new_fn_id, "retrieved function ID matches");
assert_eq!(exp.id(), original_export_id, "export ID is unchanged");
}
_ => panic!("expected an Export with a Function inside"),
}
}

/// Running `replace_exported_func` with a closure that returns None
/// should replace the function with a generated no-op function
#[test]
fn replace_exported_func_generated_no_op() {
let mut module = Module::default();

// Create original function
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
builder.func_body().i32_const(1234).drop();
let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs);
let original_export_id = module.exports.add("dummy", original_fn_id);

// Create builder to use inside closure
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);

// Replace the existing function with a new one with a reversed const value
let new_fn_id = module
.replace_exported_func(original_fn_id, move |_| {
builder.func_body().unreachable();
Ok(builder.local_func(vec![]))
})
.expect("export function replacement worked");

assert!(
module.exports.get_exported_func(original_fn_id).is_none(),
"replaced export function cannot be gotten by ID"
);

// Ensure the function was replaced
match module
.exports
.get_exported_func(new_fn_id)
.expect("failed to unwrap exported func")
{
exp @ Export {
item: ExportItem::Function(fid),
name,
..
} => {
assert_eq!(name, "dummy", "function name on export is unchanged");
assert_eq!(*fid, new_fn_id, "retrieved function ID matches");
assert_eq!(exp.id(), original_export_id, "export ID is unchanged");
}
_ => panic!("expected an Export with a Function inside"),
}
}

/// Running `replace_imported_func` with a closure that builds
/// a function should replace the existing function with the new one
#[test]
fn replace_imported_func() {
let mut module = Module::default();

// Create original import function
let types = module.types.add(&[], &[]);
let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types);

// Create builder to use inside closure
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);

// Replace the existing function with a new one with a reversed const value
let new_fn_id = module
.replace_imported_func(original_fn_id, |_| {
builder.func_body().i32_const(4321).drop();
Ok(builder.local_func(vec![]))
})
.expect("import fn replacement worked");

assert!(
!module.imports.iter().any(|i| i.id() == original_import_id),
"original import is missing",
);

assert!(
module.imports.get_imported_func(original_fn_id).is_none(),
"replaced import function cannot be gotten by ID"
);

assert!(
module.imports.get_imported_func(new_fn_id).is_none(),
"new import function cannot be gotten by ID (it is now local)"
);

assert!(
matches!(module.funcs.get(new_fn_id).kind, FunctionKind::Local(_)),
"new local function has the right kind"
);
}

/// Running `replace_imported_func` with a closure that returns None
/// should replace the function with a generated no-op function
#[test]
fn replace_imported_func_generated_no_op() {
let mut module = Module::default();

// Create original import function
let types = module.types.add(&[], &[]);
let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types);

// Create builder to use inside closure
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);

// Replace the existing function with a new one with a reversed const value
let new_fn_id = module
.replace_imported_func(original_fn_id, |_| {
builder.func_body().unreachable();
Ok(builder.local_func(vec![]))
})
.expect("import fn replacement worked");

assert!(
!module.imports.iter().any(|i| i.id() == original_import_id),
"original import is missing",
);

assert!(
module.imports.get_imported_func(original_fn_id).is_none(),
"replaced import function cannot be gotten by ID"
);

assert!(
module.imports.get_imported_func(new_fn_id).is_none(),
"new import function cannot be gotten by ID (it is now local)"
);

assert!(
matches!(module.funcs.get(new_fn_id).kind, FunctionKind::Local(_)),
"new local function has the right kind"
);
}
}
Loading

0 comments on commit 4331cb2

Please sign in to comment.