From 45ca488ae15717a67912af3a070141995e588490 Mon Sep 17 00:00:00 2001 From: Victor Adossi <123968127+vados-cosmonic@users.noreply.github.com> Date: Thu, 19 Oct 2023 05:56:51 +0900 Subject: [PATCH] refactor: improve ergonomics of fn replace (#250) After trialing `replace_(imported|exported)_func` in WASI-Virt, it's clear that the ergonomics around the builder function need to be improved. `FunctionBuilder` (particularly `FunctionBuilder::new()` is difficult to use without a mutable borrow of the module itself. This commit refactors `replace_(imported|exported)_func` in order to pass through the mutable borrow which makes it easier to use `FunctionBuilder`s. Signed-off-by: Victor Adossi --- crates/tests/tests/spec-tests.rs | 4 +- src/module/exports.rs | 14 +++ src/module/functions/mod.rs | 196 ++++++++++++++++--------------- src/module/imports.rs | 19 +++ src/module/mod.rs | 1 + 5 files changed, 136 insertions(+), 98 deletions(-) diff --git a/crates/tests/tests/spec-tests.rs b/crates/tests/tests/spec-tests.rs index d352bc05..1b1b9e28 100644 --- a/crates/tests/tests/spec-tests.rs +++ b/crates/tests/tests/spec-tests.rs @@ -124,13 +124,13 @@ fn run(wast: &Path) -> Result<(), anyhow::Error> { let wasm = fs::read(&path)?; let mut wasm = config .parse(&wasm) - .context(format!("error parsing wasm (line {})", line))?; + .with_context(|| format!("error parsing wasm (line {})", line))?; let wasm1 = wasm.emit_wasm(); fs::write(&path, &wasm1)?; let wasm2 = config .parse(&wasm1) .map(|mut m| m.emit_wasm()) - .context(format!("error re-parsing wasm (line {})", line))?; + .with_context(|| format!("error re-parsing wasm (line {})", line))?; if wasm1 != wasm2 { panic!("wasm module at line {} isn't deterministic", line); } diff --git a/src/module/exports.rs b/src/module/exports.rs index 682f170c..3a259339 100644 --- a/src/module/exports.rs +++ b/src/module/exports.rs @@ -135,6 +135,20 @@ impl ModuleExports { _ => false, }) } + + /// Delete an exported function by name from this module. + pub fn delete_func_by_name(&mut self, name: impl AsRef) -> Result<()> { + let fid = self.get_func_by_name(name.as_ref()).context(format!( + "failed to find exported func with name [{}]", + name.as_ref() + ))?; + self.delete( + self.get_exported_func(fid) + .with_context(|| format!("failed to find exported func with ID [{fid:?}]"))? + .id(), + ); + Ok(()) + } } impl Module { diff --git a/src/module/functions/mod.rs b/src/module/functions/mod.rs index dca3bb49..82741fbb 100644 --- a/src/module/functions/mod.rs +++ b/src/module/functions/mod.rs @@ -21,7 +21,7 @@ use crate::parse::IndicesToIds; use crate::tombstone_arena::{Id, Tombstone, TombstoneArena}; use crate::ty::TypeId; use crate::ty::ValType; -use crate::{ExportItem, Memory, MemoryId}; +use crate::{ExportItem, FunctionBuilder, InstrSeqBuilder, LocalId, Memory, MemoryId}; pub use self::local_function::LocalFunction; @@ -447,98 +447,118 @@ impl Module { /// Replace a single exported function with the result of the provided builder function. /// + /// The builder function is provided a mutable reference to an [`InstrSeqBuilder`] which can be + /// used to build the function as necessary. + /// /// For example, if you wanted to replace an exported function with a no-op, /// /// ```ignore - /// // Since `FunctionBuilder` requires a mutable pointer to the module's types, - /// // we must build it *outside* the closure and `move` it in - /// let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); - /// - /// module.replace_exported_func(fid, move || { + /// module.replace_exported_func(fid, |(body, arg_locals)| { /// builder.func_body().unreachable(); - /// builder.local_func(vec![]) /// }); /// ``` /// + /// The arguments passed to the original function will be passed to the + /// new exported function that was built in your closure. + /// /// This function returns the function ID of the *new* function, /// after it has been inserted into the module as an export. - pub fn replace_exported_func(&mut self, fid: FunctionId, fn_builder: F) -> Result - where - F: FnOnce((&FuncParams, &FuncResults)) -> Result, - { - match (self.exports.get_exported_func(fid), self.funcs.get(fid)) { - ( - Some(exported_fn), - Function { - kind: FunctionKind::Local(lf), - .. - }, - ) => { - // Retrieve the params & result types for the exported (local) function - let ty = self.types.get(lf.ty()); - let (params, results) = (ty.params().to_vec(), ty.results().to_vec()); - - // Add the function produced by `fn_builder` as a local function, - let new_fid = self.funcs.add_local( - fn_builder((¶ms, &results)).context("export fn builder failed")?, - ); - - // Mutate the existing export to use the new local function - let export = self.exports.get_mut(exported_fn.id()); - export.item = ExportItem::Function(new_fid); - - Ok(new_fid) - } - // The export didn't exist, or the function isn't the kind we expect - _ => bail!("cannot replace function [{fid:?}], it is not an exported function"), + pub fn replace_exported_func( + &mut self, + fid: FunctionId, + builder_fn: impl FnOnce((&mut InstrSeqBuilder, &Vec)), + ) -> Result { + let original_export_id = self + .exports + .get_exported_func(fid) + .map(|e| e.id()) + .with_context(|| format!("no exported function with ID [{fid:?}]"))?; + + if let Function { + kind: FunctionKind::Local(lf), + .. + } = self.funcs.get(fid) + { + // Retrieve the params & result types for the exported (local) function + let ty = self.types.get(lf.ty()); + let (params, results) = (ty.params().to_vec(), ty.results().to_vec()); + + // Add the function produced by `fn_builder` as a local function + let mut builder = FunctionBuilder::new(&mut self.types, ¶ms, &results); + let mut new_fn_body = builder.func_body(); + builder_fn((&mut new_fn_body, &lf.args)); + let func = builder.local_func(lf.args.clone()); + let new_fn_id = self.funcs.add_local(func); + + // Mutate the existing export to use the new local function + let export = self.exports.get_mut(original_export_id); + export.item = ExportItem::Function(new_fn_id); + Ok(new_fn_id) + } else { + bail!("cannot replace function [{fid:?}], it is not an exported function"); } } /// Replace a single imported function with the result of the provided builder function. /// - /// ```ignore - /// // Since `FunctionBuilder` requires a mutable pointer to the module's types, - /// // we must build it *outside* the closure and `move` it in - /// let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); + /// The builder function is provided a mutable reference to an [`InstrSeqBuilder`] which can be + /// used to build the function as necessary. /// - /// module.replace_imported_func(fid, move || { + /// For example, if you wanted to replace an imported function with a no-op, + /// + /// ```ignore + /// module.replace_imported_func(fid, |(body, arg_locals)| { /// builder.func_body().unreachable(); - /// builder.local_func(vec![]) /// }); /// ``` /// + /// The arguments passed to the original function will be passed to the + /// new exported function that was built in your closure. + /// /// This function returns the function ID of the *new* function, and /// removes the existing import that has been replaced (the function will become local). - pub fn replace_imported_func(&mut self, fid: FunctionId, fn_builder: F) -> Result - where - F: FnOnce((&FuncParams, &FuncResults)) -> Result, - { - // If the function is in the imports, replace it - match (self.imports.get_imported_func(fid), self.funcs.get(fid)) { - ( - Some(original_imported_fn), - Function { - kind: FunctionKind::Import(ImportedFunction { ty: tid, .. }), - .. - }, - ) => { - // Retrieve the params & result types for the imported function - let ty = self.types.get(*tid); - let (params, results) = (ty.params().to_vec(), ty.results().to_vec()); - - // Mutate the existing function, changing it from a FunctionKind::ImportedFunction - // to the local function produced by running the provided `fn_builder` - let func = self.funcs.get_mut(fid); - func.kind = FunctionKind::Local( - fn_builder((¶ms, &results)).context("import fn builder failed")?, - ); - - self.imports.delete(original_imported_fn.id()); - - Ok(fid) - } - // The export didn't exist, or the function isn't the kind we expect - _ => bail!("cannot replace function [{fid:?}], it is not an imported function"), + pub fn replace_imported_func( + &mut self, + fid: FunctionId, + builder_fn: impl FnOnce((&mut InstrSeqBuilder, &Vec)), + ) -> Result { + let original_import_id = self + .imports + .get_imported_func(fid) + .map(|import| import.id()) + .with_context(|| format!("no exported function with ID [{fid:?}]"))?; + + if let Function { + kind: FunctionKind::Import(ImportedFunction { ty: tid, .. }), + .. + } = self.funcs.get(fid) + { + // Retrieve the params & result types for the imported function + let ty = self.types.get(*tid); + let (params, results) = (ty.params().to_vec(), ty.results().to_vec()); + + // Build the list LocalIds used by args to match the original function + let args = params + .iter() + .map(|ty| self.locals.add(*ty)) + .collect::>(); + + // Build the new function + let mut builder = FunctionBuilder::new(&mut self.types, ¶ms, &results); + let mut new_fn_body = builder.func_body(); + builder_fn((&mut new_fn_body, &args)); + let new_func_kind = FunctionKind::Local(builder.local_func(args)); + + // Mutate the existing function, changing it from a FunctionKind::ImportedFunction + // to the local function produced by running the provided `fn_builder` + let func = self.funcs.get_mut(fid); + func.kind = new_func_kind; + + self.imports.delete(original_import_id); + + Ok(fid) + } else { + bail!("cannot replace function [{fid:?}], it is not an imported function"); } } } @@ -683,14 +703,10 @@ mod tests { let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs); let original_export_id = module.exports.add("dummy", original_fn_id); - // Create builder to use inside closure - let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); - // Replace the existing function with a new one with a reversed const value let new_fn_id = module - .replace_exported_func(original_fn_id, move |_| { - builder.func_body().i32_const(4321).drop(); - Ok(builder.local_func(vec![])) + .replace_exported_func(original_fn_id, |(body, _)| { + body.i32_const(4321).drop(); }) .expect("function replacement worked"); @@ -728,14 +744,10 @@ mod tests { let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs); let original_export_id = module.exports.add("dummy", original_fn_id); - // Create builder to use inside closure - let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); - // Replace the existing function with a new one with a reversed const value let new_fn_id = module - .replace_exported_func(original_fn_id, move |_| { - builder.func_body().unreachable(); - Ok(builder.local_func(vec![])) + .replace_exported_func(original_fn_id, |(body, _arg_locals)| { + body.unreachable(); }) .expect("export function replacement worked"); @@ -773,14 +785,10 @@ mod tests { let types = module.types.add(&[], &[]); let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types); - // Create builder to use inside closure - let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); - // Replace the existing function with a new one with a reversed const value let new_fn_id = module - .replace_imported_func(original_fn_id, |_| { - builder.func_body().i32_const(4321).drop(); - Ok(builder.local_func(vec![])) + .replace_imported_func(original_fn_id, |(body, _)| { + body.i32_const(4321).drop(); }) .expect("import fn replacement worked"); @@ -815,14 +823,10 @@ mod tests { let types = module.types.add(&[], &[]); let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types); - // Create builder to use inside closure - let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]); - // Replace the existing function with a new one with a reversed const value let new_fn_id = module - .replace_imported_func(original_fn_id, |_| { - builder.func_body().unreachable(); - Ok(builder.local_func(vec![])) + .replace_imported_func(original_fn_id, |(body, _arg_locals)| { + body.unreachable(); }) .expect("import fn replacement worked"); diff --git a/src/module/imports.rs b/src/module/imports.rs index 7e89821e..d74bc453 100644 --- a/src/module/imports.rs +++ b/src/module/imports.rs @@ -130,6 +130,25 @@ impl ModuleImports { _ => None, }) } + + /// Delete an imported function by name from this module. + pub fn delete_func_by_name( + &mut self, + module: impl AsRef, + name: impl AsRef, + ) -> Result<()> { + let fid = self + .get_func_by_name(module, name.as_ref()) + .with_context(|| { + format!("failed to find imported func with name [{}]", name.as_ref()) + })?; + self.delete( + self.get_imported_func(fid) + .with_context(|| format!("failed to find imported func with ID [{fid:?}]"))? + .id(), + ); + Ok(()) + } } impl Module { diff --git a/src/module/mod.rs b/src/module/mod.rs index 4703387d..74a8b699 100644 --- a/src/module/mod.rs +++ b/src/module/mod.rs @@ -27,6 +27,7 @@ pub use crate::module::debug::ModuleDebugData; pub use crate::module::elements::ElementKind; pub use crate::module::elements::{Element, ElementId, ModuleElements}; pub use crate::module::exports::{Export, ExportId, ExportItem, ModuleExports}; +pub use crate::module::functions::{FuncParams, FuncResults}; pub use crate::module::functions::{Function, FunctionId, ModuleFunctions}; pub use crate::module::functions::{FunctionKind, ImportedFunction, LocalFunction}; pub use crate::module::globals::{Global, GlobalId, GlobalKind, ModuleGlobals};