Skip to content

Commit

Permalink
Merge pull request #28 from alexcrichton/parallel
Browse files Browse the repository at this point in the history
Parallel parsing and validation
  • Loading branch information
alexcrichton authored Jan 25, 2019
2 parents 637f146 + 71d9912 commit c5a2a5f
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 63 deletions.
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,20 @@ version = "0.1.0"

[dependencies]
failure = "0.1.2"
id-arena = "2.0.1"
id-arena = { version = "2.1.0", features = ['rayon'] }
parity-wasm = "0.35.6"
petgraph = "0.4.13"
log = "0.4"
wasmparser = "0.24"
rayon = "1.0.3"

[dependencies.walrus-derive]
path = "./walrus-derive"
version = "=0.1.0"

[dev-dependencies]
env_logger = "0.6"

[workspace]
members = [
"./walrus-derive",
Expand Down
8 changes: 8 additions & 0 deletions examples/parse.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// A small example which is primarily used to help benchmark parsing in walrus
// right now.

/// Parses the wasm file named by the first command-line argument.
///
/// The parsed module is discarded; this example exists only to exercise (and
/// time) `walrus::module::Module::from_file`. `env_logger` is initialized
/// first so that messages emitted through the `log` crate can be displayed.
///
/// # Panics
///
/// Panics with a usage message if no path argument is supplied, and panics
/// with the underlying error if the file cannot be parsed.
fn main() {
    env_logger::init();

    // The path is user input, so explain what was expected instead of
    // panicking with a bare "called `Option::unwrap()` on a `None` value".
    let path = std::env::args()
        .nth(1)
        .expect("usage: parse <path-to-wasm-file>");

    walrus::module::Module::from_file(path).expect("failed to parse wasm module");
}
49 changes: 6 additions & 43 deletions src/module/functions/local_function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,43 +48,13 @@ impl LocalFunction {
/// Validates the given function body and constructs the `Expr` IR at the
/// same time.
pub fn parse(
module: &mut Module,
indices: &mut IndicesToIds,
module: &Module,
indices: &IndicesToIds,
id: FunctionId,
ty: TypeId,
body: wasmparser::FunctionBody,
args: Vec<LocalId>,
body: wasmparser::OperatorsReader,
) -> Result<LocalFunction> {
// First up, implicitly add locals for all function arguments. We also
// record these in the function itself for later processing.
let mut args = Vec::new();
for ty in module.types.get(ty).params() {
let local_id = module.locals.add(*ty);
indices.push_local(id, local_id);
args.push(local_id);
}

// WebAssembly local indices are 32 bits, so it's a validation error to
// have more than 2^32 locals. Sure enough there's a spec test for this!
let mut total = 0u32;
for local in body.get_locals_reader()? {
let (count, _) = local?;
total = match total.checked_add(count) {
Some(n) => n,
None => bail!("can't have more than 2^32 locals"),
};
}

// Now that we know we have a reasonable amount of locals, put them in
// our map.
for local in body.get_locals_reader()? {
let (count, ty) = local?;
let ty = ValType::parse(&ty)?;
for _ in 0..count {
let local_id = module.locals.add(ty);
indices.push_local(id, local_id);
}
}

let mut func = LocalFunction {
ty,
exprs: Arena::new(),
Expand All @@ -99,18 +69,11 @@ impl LocalFunction {
let operands = &mut context::OperandStack::new();
let controls = &mut context::ControlStack::new();

let mut ctx = FunctionContext::new(
module,
indices,
id,
&mut func,
operands,
controls,
);
let mut ctx = FunctionContext::new(module, indices, id, &mut func, operands, controls);

let entry = ctx.push_control(BlockKind::FunctionEntry, result.clone(), result);
ctx.func.entry = Some(entry);
validate_expression(&mut ctx, body.get_operators_reader()?)?;
validate_expression(&mut ctx, body)?;

debug_assert_eq!(ctx.operands.len(), result_len);
debug_assert!(ctx.controls.is_empty());
Expand Down
72 changes: 66 additions & 6 deletions src/module/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ use crate::module::imports::ImportId;
use crate::module::Module;
use crate::parse::IndicesToIds;
use crate::ty::TypeId;
use crate::ty::ValType;
use failure::bail;
use id_arena::{Arena, Id};
use parity_wasm::elements;
use rayon::prelude::*;
use std::cmp;
use std::fmt;

Expand Down Expand Up @@ -179,6 +181,11 @@ impl ModuleFunctions {
pub fn iter(&self) -> impl Iterator<Item = &Function> {
    // The arena yields `(id, function)` pairs; only the function is exposed.
    self.arena.iter().map(|entry| entry.1)
}

/// Get a parallel iterator over shared references to this module's
/// functions.
///
/// This is the parallel counterpart of `iter`: it walks the underlying arena
/// with `par_iter` and yields only the `&Function` from each
/// `(id, function)` pair, dropping the id.
pub fn par_iter(&self) -> impl ParallelIterator<Item = &Function> {
self.arena.par_iter().map(|(_, f)| f)
}
}

impl Module {
Expand Down Expand Up @@ -212,6 +219,11 @@ impl Module {
}
let num_imports = self.funcs.arena.len() - (amt as usize);

// First up serially create corresponding `LocalId` instances for all
// functions as well as extract the operators parser for each function.
// This is pretty tough to parallelize, but we can look into it later if
// necessary and it's a bottleneck!
let mut bodies = Vec::with_capacity(amt as usize);
for (i, body) in section.into_iter().enumerate() {
let body = body?;
let index = (num_imports + i) as u32;
Expand All @@ -221,12 +233,60 @@ impl Module {
_ => unreachable!(),
};

let local = LocalFunction::parse(self, indices, id, ty, body)?;
self.funcs.arena[id] = Function {
id,
kind: FunctionKind::Local(local),
name: None,
};
// First up, implicitly add locals for all function arguments. We also
// record these in the function itself for later processing.
let mut args = Vec::new();
for ty in self.types.get(ty).params() {
let local_id = self.locals.add(*ty);
indices.push_local(id, local_id);
args.push(local_id);
}

// WebAssembly local indices are 32 bits, so it's a validation error for the
// number of declared locals to overflow a `u32` (i.e. to reach 2^32 or
// more). Sure enough there's a spec test for this!
let mut total = 0u32;
for local in body.get_locals_reader()? {
let (count, _) = local?;
total = match total.checked_add(count) {
Some(n) => n,
None => bail!("can't have more than 2^32 locals"),
};
}

// Now that we know we have a reasonable amount of locals, put them in
// our map.
for local in body.get_locals_reader()? {
let (count, ty) = local?;
let ty = ValType::parse(&ty)?;
for _ in 0..count {
let local_id = self.locals.add(ty);
indices.push_local(id, local_id);
}
}

let body = body.get_operators_reader()?;
bodies.push((id, body, args, ty));
}

// Wasm modules can often have a lot of functions and this operation can
// take some time, so parse all function bodies in parallel.
let results = bodies
.into_par_iter()
.map(|(id, body, args, ty)| {
LocalFunction::parse(self, indices, id, ty, args, body).map(|local| Function {
id,
kind: FunctionKind::Local(local),
name: None,
})
})
.collect::<Vec<_>>();

// After all the function bodies are collected and finished push them
// into our function arena.
for func in results {
let func = func?;
let id = func.id;
self.funcs.arena[id] = func;
}

Ok(())
Expand Down
14 changes: 14 additions & 0 deletions src/module/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,56 +81,67 @@ impl Module {
let section = parser.read()?;
match section.code {
wasmparser::SectionCode::Data => {
log::debug!("parsing data section");
let reader = section.get_data_section_reader()?;
ret.parse_data(reader, &mut indices)
.context("failed to parse data section")?;
}
wasmparser::SectionCode::Type => {
log::debug!("parsing type section");
let reader = section.get_type_section_reader()?;
ret.parse_types(reader, &mut indices)
.context("failed to parse type section")?;
}
wasmparser::SectionCode::Import => {
log::debug!("parsing import section");
let reader = section.get_import_section_reader()?;
ret.parse_imports(reader, &mut indices)
.context("failed to parse import section")?;
}
wasmparser::SectionCode::Table => {
log::debug!("parsing table section");
let reader = section.get_table_section_reader()?;
ret.parse_tables(reader, &mut indices)
.context("failed to parse table section")?;
}
wasmparser::SectionCode::Memory => {
log::debug!("parsing memory section");
let reader = section.get_memory_section_reader()?;
ret.parse_memories(reader, &mut indices)
.context("failed to parse memory section")?;
}
wasmparser::SectionCode::Global => {
log::debug!("parsing global section");
let reader = section.get_global_section_reader()?;
ret.parse_globals(reader, &mut indices)
.context("failed to parse global section")?;
}
wasmparser::SectionCode::Export => {
log::debug!("parsing export section");
let reader = section.get_export_section_reader()?;
ret.parse_exports(reader, &mut indices)
.context("failed to parse export section")?;
}
wasmparser::SectionCode::Element => {
log::debug!("parsing element section");
let reader = section.get_element_section_reader()?;
ret.parse_elements(reader, &mut indices)
.context("failed to parse element section")?;
}
wasmparser::SectionCode::Start => {
log::debug!("parsing start section");
let idx = section.get_start_section_content()?;
ret.start = Some(indices.get_func(idx)?);
}
wasmparser::SectionCode::Function => {
log::debug!("parsing function section");
let reader = section.get_function_section_reader()?;
function_section_size = Some(reader.get_count());
ret.declare_local_functions(reader, &mut indices)
.context("failed to parse function section")?;
}
wasmparser::SectionCode::Code => {
log::debug!("parsing code section");
let function_section_size = match function_section_size.take() {
Some(i) => i,
None => bail!("cannot have a code section without function section"),
Expand All @@ -140,6 +151,7 @@ impl Module {
.context("failed to parse code section")?;
}
wasmparser::SectionCode::Custom { name, kind: _ } => {
log::debug!("parsing custom section `{}`", name);
let result = match name {
"producers" => {
let reader = section.get_binary_reader();
Expand Down Expand Up @@ -175,8 +187,10 @@ impl Module {
.add_processed_by("walrus", env!("CARGO_PKG_VERSION"));

// TODO: probably run this in a different location
log::debug!("validating module");
crate::passes::validate::run(&ret)?;

log::debug!("parse complete");
Ok(ret)
}

Expand Down
35 changes: 22 additions & 13 deletions src/passes/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::module::tables::{Table, TableKind};
use crate::module::Module;
use crate::ty::ValType;
use failure::{bail, ResultExt};
use rayon::prelude::*;
use std::collections::HashSet;

/// Validate a wasm module, returning an error if it fails to validate.
Expand Down Expand Up @@ -48,19 +49,27 @@ pub fn run(module: &Module) -> Result<()> {

// Validate each function in the module, collecting errors and returning
// them all at once if there are any.
let mut errs = Vec::new();
for function in module.funcs.iter() {
let local = match &function.kind {
FunctionKind::Local(local) => local,
_ => continue,
};
let mut cx = Validate {
errs: &mut errs,
function,
local,
};
local.entry_block().visit(&mut cx);
}
let errs = module
.funcs
.par_iter()
.map(|function| {
let mut errs = Vec::new();
let local = match &function.kind {
FunctionKind::Local(local) => local,
_ => return Vec::new(),
};
let mut cx = Validate {
errs: &mut errs,
function,
local,
};
local.entry_block().visit(&mut cx);
errs
})
.reduce(Vec::new, |mut a, b| {
a.extend(b);
a
});
if errs.len() == 0 {
return Ok(());
}
Expand Down

0 comments on commit c5a2a5f

Please sign in to comment.