Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: improve ergonomics of fn replace #250

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/tests/tests/spec-tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,13 @@ fn run(wast: &Path) -> Result<(), anyhow::Error> {
let wasm = fs::read(&path)?;
let mut wasm = config
.parse(&wasm)
.context(format!("error parsing wasm (line {})", line))?;
.with_context(|| format!("error parsing wasm (line {})", line))?;
let wasm1 = wasm.emit_wasm();
fs::write(&path, &wasm1)?;
let wasm2 = config
.parse(&wasm1)
.map(|mut m| m.emit_wasm())
.context(format!("error re-parsing wasm (line {})", line))?;
.with_context(|| format!("error re-parsing wasm (line {})", line))?;
if wasm1 != wasm2 {
panic!("wasm module at line {} isn't deterministic", line);
}
Expand Down
14 changes: 14 additions & 0 deletions src/module/exports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,20 @@ impl ModuleExports {
_ => false,
})
}

/// Delete an exported function by name from this module.
pub fn delete_func_by_name(&mut self, name: impl AsRef<str>) -> Result<()> {
let fid = self.get_func_by_name(name.as_ref()).context(format!(
"failed to find exported func with name [{}]",
name.as_ref()
))?;
self.delete(
self.get_exported_func(fid)
.with_context(|| format!("failed to find exported func with ID [{fid:?}]"))?
.id(),
);
Ok(())
}
}

impl Module {
Expand Down
196 changes: 100 additions & 96 deletions src/module/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::parse::IndicesToIds;
use crate::tombstone_arena::{Id, Tombstone, TombstoneArena};
use crate::ty::TypeId;
use crate::ty::ValType;
use crate::{ExportItem, Memory, MemoryId};
use crate::{ExportItem, FunctionBuilder, InstrSeqBuilder, LocalId, Memory, MemoryId};

pub use self::local_function::LocalFunction;

Expand Down Expand Up @@ -447,98 +447,118 @@ impl Module {

/// Replace a single exported function with the result of the provided builder function.
///
/// The builder function is provided a mutable reference to an [`InstrSeqBuilder`] which can be
/// used to build the function as necessary.
///
/// For example, if you wanted to replace an exported function with a no-op,
///
/// ```ignore
/// // Since `FunctionBuilder` requires a mutable pointer to the module's types,
/// // we must build it *outside* the closure and `move` it in
/// let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
///
/// module.replace_exported_func(fid, move || {
/// module.replace_exported_func(fid, |(body, arg_locals)| {
/// builder.func_body().unreachable();
/// builder.local_func(vec![])
/// });
/// ```
///
/// The arguments passed to the original function will be passed to the
/// new exported function that was built in your closure.
///
/// This function returns the function ID of the *new* function,
/// after it has been inserted into the module as an export.
pub fn replace_exported_func<F>(&mut self, fid: FunctionId, fn_builder: F) -> Result<FunctionId>
where
F: FnOnce((&FuncParams, &FuncResults)) -> Result<LocalFunction>,
{
match (self.exports.get_exported_func(fid), self.funcs.get(fid)) {
(
Some(exported_fn),
Function {
kind: FunctionKind::Local(lf),
..
},
) => {
// Retrieve the params & result types for the exported (local) function
let ty = self.types.get(lf.ty());
let (params, results) = (ty.params().to_vec(), ty.results().to_vec());

// Add the function produced by `fn_builder` as a local function,
let new_fid = self.funcs.add_local(
fn_builder((&params, &results)).context("export fn builder failed")?,
);

// Mutate the existing export to use the new local function
let export = self.exports.get_mut(exported_fn.id());
export.item = ExportItem::Function(new_fid);

Ok(new_fid)
}
// The export didn't exist, or the function isn't the kind we expect
_ => bail!("cannot replace function [{fid:?}], it is not an exported function"),
pub fn replace_exported_func(
&mut self,
fid: FunctionId,
builder_fn: impl FnOnce((&mut InstrSeqBuilder, &Vec<LocalId>)),
) -> Result<FunctionId> {
let original_export_id = self
.exports
.get_exported_func(fid)
.map(|e| e.id())
.with_context(|| format!("no exported function with ID [{fid:?}]"))?;

if let Function {
kind: FunctionKind::Local(lf),
..
} = self.funcs.get(fid)
{
// Retrieve the params & result types for the exported (local) function
let ty = self.types.get(lf.ty());
let (params, results) = (ty.params().to_vec(), ty.results().to_vec());

// Add the function produced by `fn_builder` as a local function
let mut builder = FunctionBuilder::new(&mut self.types, &params, &results);
let mut new_fn_body = builder.func_body();
builder_fn((&mut new_fn_body, &lf.args));
let func = builder.local_func(lf.args.clone());
guybedford marked this conversation as resolved.
Show resolved Hide resolved
let new_fn_id = self.funcs.add_local(func);

// Mutate the existing export to use the new local function
let export = self.exports.get_mut(original_export_id);
export.item = ExportItem::Function(new_fn_id);
Ok(new_fn_id)
} else {
bail!("cannot replace function [{fid:?}], it is not an exported function");
}
}

/// Replace a single imported function with the result of the provided builder function.
///
/// ```ignore
/// // Since `FunctionBuilder` requires a mutable pointer to the module's types,
/// // we must build it *outside* the closure and `move` it in
/// let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
/// The builder function is provided a mutable reference to an [`InstrSeqBuilder`] which can be
/// used to build the function as necessary.
///
/// module.replace_imported_func(fid, move || {
/// For example, if you wanted to replace an imported function with a no-op,
///
/// ```ignore
/// module.replace_imported_func(fid, |(body, arg_locals)| {
/// builder.func_body().unreachable();
/// builder.local_func(vec![])
/// });
/// ```
///
/// The arguments passed to the original function will be passed to the
/// new exported function that was built in your closure.
///
/// This function returns the function ID of the *new* function, and
/// removes the existing import that has been replaced (the function will become local).
pub fn replace_imported_func<F>(&mut self, fid: FunctionId, fn_builder: F) -> Result<FunctionId>
where
F: FnOnce((&FuncParams, &FuncResults)) -> Result<LocalFunction>,
{
// If the function is in the imports, replace it
match (self.imports.get_imported_func(fid), self.funcs.get(fid)) {
(
Some(original_imported_fn),
Function {
kind: FunctionKind::Import(ImportedFunction { ty: tid, .. }),
..
},
) => {
// Retrieve the params & result types for the imported function
let ty = self.types.get(*tid);
let (params, results) = (ty.params().to_vec(), ty.results().to_vec());

// Mutate the existing function, changing it from a FunctionKind::ImportedFunction
// to the local function produced by running the provided `fn_builder`
let func = self.funcs.get_mut(fid);
func.kind = FunctionKind::Local(
fn_builder((&params, &results)).context("import fn builder failed")?,
);

self.imports.delete(original_imported_fn.id());

Ok(fid)
}
// The export didn't exist, or the function isn't the kind we expect
_ => bail!("cannot replace function [{fid:?}], it is not an imported function"),
pub fn replace_imported_func(
&mut self,
fid: FunctionId,
builder_fn: impl FnOnce((&mut InstrSeqBuilder, &Vec<LocalId>)),
) -> Result<FunctionId> {
let original_import_id = self
.imports
.get_imported_func(fid)
.map(|import| import.id())
.with_context(|| format!("no exported function with ID [{fid:?}]"))?;

if let Function {
kind: FunctionKind::Import(ImportedFunction { ty: tid, .. }),
..
} = self.funcs.get(fid)
{
// Retrieve the params & result types for the imported function
let ty = self.types.get(*tid);
let (params, results) = (ty.params().to_vec(), ty.results().to_vec());

// Build the list LocalIds used by args to match the original function
let args = params
.iter()
.map(|ty| self.locals.add(*ty))
.collect::<Vec<_>>();

// Build the new function
let mut builder = FunctionBuilder::new(&mut self.types, &params, &results);
let mut new_fn_body = builder.func_body();
builder_fn((&mut new_fn_body, &args));
let new_func_kind = FunctionKind::Local(builder.local_func(args));

// Mutate the existing function, changing it from a FunctionKind::ImportedFunction
// to the local function produced by running the provided `fn_builder`
let func = self.funcs.get_mut(fid);
func.kind = new_func_kind;

self.imports.delete(original_import_id);

Ok(fid)
} else {
bail!("cannot replace function [{fid:?}], it is not an imported function");
}
}
}
Expand Down Expand Up @@ -683,14 +703,10 @@ mod tests {
let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs);
let original_export_id = module.exports.add("dummy", original_fn_id);

// Create builder to use inside closure
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);

// Replace the existing function with a new one with a reversed const value
let new_fn_id = module
.replace_exported_func(original_fn_id, move |_| {
builder.func_body().i32_const(4321).drop();
Ok(builder.local_func(vec![]))
.replace_exported_func(original_fn_id, |(body, _)| {
body.i32_const(4321).drop();
})
.expect("function replacement worked");

Expand Down Expand Up @@ -728,14 +744,10 @@ mod tests {
let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs);
let original_export_id = module.exports.add("dummy", original_fn_id);

// Create builder to use inside closure
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);

// Replace the existing function with a new one with a reversed const value
let new_fn_id = module
.replace_exported_func(original_fn_id, move |_| {
builder.func_body().unreachable();
Ok(builder.local_func(vec![]))
.replace_exported_func(original_fn_id, |(body, _arg_locals)| {
body.unreachable();
})
.expect("export function replacement worked");

Expand Down Expand Up @@ -773,14 +785,10 @@ mod tests {
let types = module.types.add(&[], &[]);
let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types);

// Create builder to use inside closure
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);

// Replace the existing function with a new one with a reversed const value
let new_fn_id = module
.replace_imported_func(original_fn_id, |_| {
builder.func_body().i32_const(4321).drop();
Ok(builder.local_func(vec![]))
.replace_imported_func(original_fn_id, |(body, _)| {
body.i32_const(4321).drop();
})
.expect("import fn replacement worked");

Expand Down Expand Up @@ -815,14 +823,10 @@ mod tests {
let types = module.types.add(&[], &[]);
let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types);

// Create builder to use inside closure
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);

// Replace the existing function with a new one with a reversed const value
let new_fn_id = module
.replace_imported_func(original_fn_id, |_| {
builder.func_body().unreachable();
Ok(builder.local_func(vec![]))
.replace_imported_func(original_fn_id, |(body, _arg_locals)| {
body.unreachable();
})
.expect("import fn replacement worked");

Expand Down
19 changes: 19 additions & 0 deletions src/module/imports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,25 @@ impl ModuleImports {
_ => None,
})
}

/// Delete an imported function by name from this module.
pub fn delete_func_by_name(
&mut self,
module: impl AsRef<str>,
name: impl AsRef<str>,
) -> Result<()> {
let fid = self
.get_func_by_name(module, name.as_ref())
.with_context(|| {
format!("failed to find imported func with name [{}]", name.as_ref())
})?;
self.delete(
self.get_imported_func(fid)
.with_context(|| format!("failed to find imported func with ID [{fid:?}]"))?
.id(),
);
Ok(())
}
}

impl Module {
Expand Down
1 change: 1 addition & 0 deletions src/module/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub use crate::module::debug::ModuleDebugData;
pub use crate::module::elements::ElementKind;
pub use crate::module::elements::{Element, ElementId, ModuleElements};
pub use crate::module::exports::{Export, ExportId, ExportItem, ModuleExports};
pub use crate::module::functions::{FuncParams, FuncResults};
pub use crate::module::functions::{Function, FunctionId, ModuleFunctions};
pub use crate::module::functions::{FunctionKind, ImportedFunction, LocalFunction};
pub use crate::module::globals::{Global, GlobalId, GlobalKind, ModuleGlobals};
Expand Down
Loading