diff --git a/Cargo.toml b/Cargo.toml index 921291df..177c8376 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,16 +6,20 @@ version = "0.1.0" [dependencies] failure = "0.1.2" -id-arena = "2.0.1" +id-arena = { version = "2.1.0", features = ['rayon'] } parity-wasm = "0.35.6" petgraph = "0.4.13" log = "0.4" wasmparser = "0.24" +rayon = "1.0.3" [dependencies.walrus-derive] path = "./walrus-derive" version = "=0.1.0" +[dev-dependencies] +env_logger = "0.6" + [workspace] members = [ "./walrus-derive", diff --git a/examples/parse.rs b/examples/parse.rs new file mode 100644 index 00000000..b259ada2 --- /dev/null +++ b/examples/parse.rs @@ -0,0 +1,8 @@ +// A small example which is primarily used to help benchmark parsing in walrus +// right now. + +fn main() { + env_logger::init(); + let a = std::env::args().nth(1).unwrap(); + walrus::module::Module::from_file(a).unwrap(); +} diff --git a/src/module/functions/local_function/mod.rs b/src/module/functions/local_function/mod.rs index d207d425..7c80bf18 100644 --- a/src/module/functions/local_function/mod.rs +++ b/src/module/functions/local_function/mod.rs @@ -48,43 +48,13 @@ impl LocalFunction { /// Validates the given function body and constructs the `Expr` IR at the /// same time. pub fn parse( - module: &mut Module, - indices: &mut IndicesToIds, + module: &Module, + indices: &IndicesToIds, id: FunctionId, ty: TypeId, - body: wasmparser::FunctionBody, + args: Vec, + body: wasmparser::OperatorsReader, ) -> Result { - // First up, implicitly add locals for all function arguments. We also - // record these in the function itself for later processing. - let mut args = Vec::new(); - for ty in module.types.get(ty).params() { - let local_id = module.locals.add(*ty); - indices.push_local(id, local_id); - args.push(local_id); - } - - // WebAssembly local indices are 32 bits, so it's a validation error to - // have more than 2^32 locals. Sure enough there's a spec test for this! - let mut total = 0u32; - for local in body.get_locals_reader()? { - let (count, _) = local?; - total = match total.checked_add(count) { - Some(n) => n, - None => bail!("can't have more than 2^32 locals"), - }; - } - - // Now that we know we have a reasonable amount of locals, put them in - // our map. - for local in body.get_locals_reader()? { - let (count, ty) = local?; - let ty = ValType::parse(&ty)?; - for _ in 0..count { - let local_id = module.locals.add(ty); - indices.push_local(id, local_id); - } - } - let mut func = LocalFunction { ty, exprs: Arena::new(), @@ -99,18 +69,11 @@ impl LocalFunction { let operands = &mut context::OperandStack::new(); let controls = &mut context::ControlStack::new(); - let mut ctx = FunctionContext::new( - module, - indices, - id, - &mut func, - operands, - controls, - ); + let mut ctx = FunctionContext::new(module, indices, id, &mut func, operands, controls); let entry = ctx.push_control(BlockKind::FunctionEntry, result.clone(), result); ctx.func.entry = Some(entry); - validate_expression(&mut ctx, body.get_operators_reader()?)?; + validate_expression(&mut ctx, body)?; debug_assert_eq!(ctx.operands.len(), result_len); debug_assert!(ctx.controls.is_empty()); diff --git a/src/module/functions/mod.rs b/src/module/functions/mod.rs index b8131deb..97ce3bc8 100644 --- a/src/module/functions/mod.rs +++ b/src/module/functions/mod.rs @@ -9,9 +9,11 @@ use crate::module::imports::ImportId; use crate::module::Module; use crate::parse::IndicesToIds; use crate::ty::TypeId; +use crate::ty::ValType; use failure::bail; use id_arena::{Arena, Id}; use parity_wasm::elements; +use rayon::prelude::*; use std::cmp; use std::fmt; @@ -179,6 +181,11 @@ impl ModuleFunctions { pub fn iter(&self) -> impl Iterator { self.arena.iter().map(|(_, f)| f) } + + /// Get a shared reference to this module's functions. + pub fn par_iter(&self) -> impl ParallelIterator { + self.arena.par_iter().map(|(_, f)| f) + } } impl Module { @@ -212,6 +219,11 @@ impl Module { } let num_imports = self.funcs.arena.len() - (amt as usize); + // First up serially create corresponding `LocalId` instances for all + // functions as well as extract the operators parser for each function. + // This is pretty tough to parallelize, but we can look into it later if + // necessary and it's a bottleneck! + let mut bodies = Vec::with_capacity(amt as usize); for (i, body) in section.into_iter().enumerate() { let body = body?; let index = (num_imports + i) as u32; @@ -221,12 +233,60 @@ impl Module { _ => unreachable!(), }; - let local = LocalFunction::parse(self, indices, id, ty, body)?; - self.funcs.arena[id] = Function { - id, - kind: FunctionKind::Local(local), - name: None, - }; + // First up, implicitly add locals for all function arguments. We also + // record these in the function itself for later processing. + let mut args = Vec::new(); + for ty in self.types.get(ty).params() { + let local_id = self.locals.add(*ty); + indices.push_local(id, local_id); + args.push(local_id); + } + + // WebAssembly local indices are 32 bits, so it's a validation error to + // have more than 2^32 locals. Sure enough there's a spec test for this! + let mut total = 0u32; + for local in body.get_locals_reader()? { + let (count, _) = local?; + total = match total.checked_add(count) { + Some(n) => n, + None => bail!("can't have more than 2^32 locals"), + }; + } + + // Now that we know we have a reasonable amount of locals, put them in + // our map. + for local in body.get_locals_reader()? { + let (count, ty) = local?; + let ty = ValType::parse(&ty)?; + for _ in 0..count { + let local_id = self.locals.add(ty); + indices.push_local(id, local_id); + } + } + + let body = body.get_operators_reader()?; + bodies.push((id, body, args, ty)); + } + + // Wasm modules can often have a lot of functions and this operation can + // take some time, so parse all function bodies in parallel. + let results = bodies + .into_par_iter() + .map(|(id, body, args, ty)| { + LocalFunction::parse(self, indices, id, ty, args, body).map(|local| Function { + id, + kind: FunctionKind::Local(local), + name: None, + }) + }) + .collect::>(); + + // After all the function bodies are collected and finished push them + // into our function arena. + for func in results { + let func = func?; + let id = func.id; + self.funcs.arena[id] = func; } Ok(()) diff --git a/src/module/mod.rs b/src/module/mod.rs index 5b58506a..33bfce3e 100644 --- a/src/module/mod.rs +++ b/src/module/mod.rs @@ -81,56 +81,67 @@ impl Module { let section = parser.read()?; match section.code { wasmparser::SectionCode::Data => { + log::debug!("parsing data section"); let reader = section.get_data_section_reader()?; ret.parse_data(reader, &mut indices) .context("failed to parse data section")?; } wasmparser::SectionCode::Type => { + log::debug!("parsing type section"); let reader = section.get_type_section_reader()?; ret.parse_types(reader, &mut indices) .context("failed to parse type section")?; } wasmparser::SectionCode::Import => { + log::debug!("parsing import section"); let reader = section.get_import_section_reader()?; ret.parse_imports(reader, &mut indices) .context("failed to parse import section")?; } wasmparser::SectionCode::Table => { + log::debug!("parsing table section"); let reader = section.get_table_section_reader()?; ret.parse_tables(reader, &mut indices) .context("failed to parse table section")?; } wasmparser::SectionCode::Memory => { + log::debug!("parsing memory section"); let reader = section.get_memory_section_reader()?; ret.parse_memories(reader, &mut indices) .context("failed to parse memory section")?; } wasmparser::SectionCode::Global => { + log::debug!("parsing global section"); let reader = section.get_global_section_reader()?; ret.parse_globals(reader, &mut indices) .context("failed to parse global section")?; } wasmparser::SectionCode::Export => { + log::debug!("parsing export section"); let reader = section.get_export_section_reader()?; ret.parse_exports(reader, &mut indices) .context("failed to parse export section")?; } wasmparser::SectionCode::Element => { + log::debug!("parsing element section"); let reader = section.get_element_section_reader()?; ret.parse_elements(reader, &mut indices) .context("failed to parse element section")?; } wasmparser::SectionCode::Start => { + log::debug!("parsing start section"); let idx = section.get_start_section_content()?; ret.start = Some(indices.get_func(idx)?); } wasmparser::SectionCode::Function => { + log::debug!("parsing function section"); let reader = section.get_function_section_reader()?; function_section_size = Some(reader.get_count()); ret.declare_local_functions(reader, &mut indices) .context("failed to parse function section")?; } wasmparser::SectionCode::Code => { + log::debug!("parsing code section"); let function_section_size = match function_section_size.take() { Some(i) => i, None => bail!("cannot have a code section without function section"), @@ -140,6 +151,7 @@ impl Module { .context("failed to parse code section")?; } wasmparser::SectionCode::Custom { name, kind: _ } => { + log::debug!("parsing custom section `{}`", name); let result = match name { "producers" => { let reader = section.get_binary_reader(); @@ -175,8 +187,10 @@ impl Module { .add_processed_by("walrus", env!("CARGO_PKG_VERSION")); // TODO: probably run this in a different location + log::debug!("validating module"); crate::passes::validate::run(&ret)?; + log::debug!("parse complete"); Ok(ret) } diff --git a/src/passes/validate.rs b/src/passes/validate.rs index c0bce564..6446a43b 100644 --- a/src/passes/validate.rs +++ b/src/passes/validate.rs @@ -13,6 +13,7 @@ use crate::module::tables::{Table, TableKind}; use crate::module::Module; use crate::ty::ValType; use failure::{bail, ResultExt}; +use rayon::prelude::*; use std::collections::HashSet; /// Validate a wasm module, returning an error if it fails to validate. @@ -48,19 +49,27 @@ pub fn run(module: &Module) -> Result<()> { // Validate each function in the module, collecting errors and returning // them all at once if there are any. - let mut errs = Vec::new(); - for function in module.funcs.iter() { - let local = match &function.kind { - FunctionKind::Local(local) => local, - _ => continue, - }; - let mut cx = Validate { - errs: &mut errs, - function, - local, - }; - local.entry_block().visit(&mut cx); - } + let errs = module + .funcs + .par_iter() + .map(|function| { + let mut errs = Vec::new(); + let local = match &function.kind { + FunctionKind::Local(local) => local, + _ => return Vec::new(), + }; + let mut cx = Validate { + errs: &mut errs, + function, + local, + }; + local.entry_block().visit(&mut cx); + errs + }) + .reduce(Vec::new, |mut a, b| { + a.extend(b); + a + }); if errs.len() == 0 { return Ok(()); }