Skip to content

Commit

Permalink
Add benchmark for Wasm invocations (#4664)
Browse files Browse the repository at this point in the history
Only create a single linker, and reuse that to create all subsequent instances.

Add benchmarks to measure it.

Before:

```
test wasm::tests::bench_invoke ... bench:     246,826 ns/iter (+/- 49,215)
```

After:

```
test wasm::tests::bench_invoke ... bench:      45,246 ns/iter (+/- 1,894)
```

Ref #3757
  • Loading branch information
tiziano88 authored Jan 18, 2024
1 parent ac6709f commit a91172d
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 29 deletions.
2 changes: 2 additions & 0 deletions oak_functions_service/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![feature(never_type)]
#![feature(unwrap_infallible)]
// Required for enabling benchmark tests.
#![feature(test)]

extern crate alloc;

Expand Down
37 changes: 13 additions & 24 deletions oak_functions_service/src/wasm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use log::Level;
use micro_rpc::StatusCode;
use oak_functions_abi::{Request, Response};
use spinning_top::Spinlock;
use wasmi::{MemoryType, Store};
use wasmi::Store;

use crate::{
logger::{OakLogger, StandaloneLogger},
Expand Down Expand Up @@ -119,20 +119,9 @@ impl<L> OakLinker<L>
where
L: OakLogger,
{
fn new(engine: &wasmi::Engine, store: &mut Store<UserState<L>>) -> Self {
fn new(engine: &wasmi::Engine) -> Self {
let mut linker: wasmi::Linker<UserState<L>> = wasmi::Linker::new(engine);

// Add memory to linker.
// TODO(#3783): Find a sensible value for initial pages.
let initial_pages = 100;
let memory_type =
MemoryType::new(initial_pages, None).expect("failed to define Wasm memory type");
let memory =
wasmi::Memory::new(store, memory_type).expect("failed to initialize Wasm memory");
linker
.define(OAK_FUNCTIONS, MEMORY_NAME, wasmi::Extern::Memory(memory))
.expect("failed to define Wasm memory in linker");

linker
.func_wrap(
OAK_FUNCTIONS,
Expand Down Expand Up @@ -218,13 +207,11 @@ where

/// Instantiates the Oak Linker and checks whether the instance exports `main`, `alloc` and a
/// memory is attached.
///
/// Use the same store used when creating the linker.
fn instantiate(
self,
mut store: Store<UserState<L>>,
&self,
mut store: &mut Store<UserState<L>>,
module: Arc<wasmi::Module>,
) -> Result<(wasmi::Instance, Store<UserState<L>>), micro_rpc::Status> {
) -> Result<wasmi::Instance, micro_rpc::Status> {
let instance = self
.linker
.instantiate(&mut store, &module)
Expand All @@ -245,7 +232,7 @@ where

// Check that the instance exports "main".
let _ = &instance
.get_typed_func::<(), ()>(&store, MAIN_FUNCTION_NAME)
.get_typed_func::<(), ()>(&mut store, MAIN_FUNCTION_NAME)
.map_err(|err| {
micro_rpc::Status::new_with_message(
micro_rpc::StatusCode::Internal,
Expand All @@ -255,7 +242,7 @@ where

// Check that the instance exports "alloc".
let _ = &instance
.get_typed_func::<i32, AbiPointer>(&store, ALLOC_FUNCTION_NAME)
.get_typed_func::<i32, AbiPointer>(&mut store, ALLOC_FUNCTION_NAME)
.map_err(|err| {
micro_rpc::Status::new_with_message(
micro_rpc::StatusCode::Internal,
Expand All @@ -272,7 +259,7 @@ where
)
})?;

Ok((instance, store))
Ok(instance)
}
}

Expand Down Expand Up @@ -423,9 +410,9 @@ where
}

// A request handler with a Wasm module for handling multiple requests.
#[derive(Clone)]
pub struct WasmHandler<L: OakLogger> {
wasm_module: Arc<wasmi::Module>,
linker: OakLinker<L>,
wasm_api_factory: Arc<dyn WasmApiFactory<L> + Send + Sync>,
logger: L,
#[cfg_attr(not(feature = "std"), allow(dead_code))]
Expand Down Expand Up @@ -466,8 +453,11 @@ where
let module = wasmi::Module::new(&engine, wasm_module_bytes)
.map_err(|err| anyhow::anyhow!("couldn't load module from buffer: {:?}", err))?;

let linker = OakLinker::new(module.engine());

Ok(WasmHandler {
wasm_module: Arc::new(module),
linker,
wasm_api_factory,
logger,
observer,
Expand All @@ -489,8 +479,7 @@ where
let user_state = UserState::new(wasm_api.transport(), self.logger.clone());
// For isolated requests we need to create a new store for every request.
let mut store = wasmi::Store::new(module.engine(), user_state);
let linker = OakLinker::new(module.engine(), &mut store);
let (instance, mut store) = linker.instantiate(store, module)?;
let instance = self.linker.instantiate(&mut store, module)?;

instance.exports(&store).for_each(|export| {
store
Expand Down
64 changes: 59 additions & 5 deletions oak_functions_service/src/wasm/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@
// limitations under the License.
//

extern crate test;

use alloc::{sync::Arc, vec::Vec};
use std::time::Duration;

use byteorder::{ByteOrder, LittleEndian};
use hashbrown::HashMap;
use oak_functions_abi::Request;
use spinning_top::Spinlock;
use test::Bencher;

use super::{
api::StdWasmApiFactory, OakLinker, UserState, WasmApiFactory, WasmHandler, ALLOC_FUNCTION_NAME,
Expand Down Expand Up @@ -125,9 +130,57 @@ fn test_read_request() {
assert_eq!(request_bytes, test_state.request.clone());
}

#[test]
fn test_invoke() {
let test_state = create_test_state();
let data = b"Hello, world!";
let response = test_state
.wasm_handler
.handle_invoke(Request {
body: data.to_vec(),
})
.unwrap();
assert_eq!(response.body, data.to_vec());
}

#[bench]
fn bench_invoke(bencher: &mut Bencher) {
let test_state = create_test_state();
let data = b"Hello, world!";

let summary = bencher.bench(|bencher| {
bencher.iter(|| {
let response = test_state
.wasm_handler
.handle_invoke(Request {
body: data.to_vec(),
})
.unwrap();
assert_eq!(response.body, data.to_vec());
});
Ok(())
});

// When running `cargo test` this benchmark test gets executed too, but `summary` will be `None`
// in that case. So, here we first check that `summary` is not empty.
if let Ok(Some(summary)) = summary {
// `summary.mean` is in nanoseconds, even though it is not explicitly documented in
// https://doc.rust-lang.org/test/stats/struct.Summary.html.
let elapsed = Duration::from_nanos(summary.mean as u64);
// We expect the `mean` time for loading the test Wasm module and running its main function
// to be less than a fixed threshold.
assert!(
elapsed < Duration::from_micros(100),
"elapsed time: {:.0?}",
elapsed
);
}
}

struct TestState {
instance: wasmi::Instance,
store: wasmi::Store<UserState<StandaloneLogger>>,
wasm_handler: WasmHandler<StandaloneLogger>,
request: Vec<u8>,
}

Expand Down Expand Up @@ -155,16 +208,17 @@ fn create_test_state() -> TestState {

let user_state = UserState::new(wasm_api.transport(), logger.clone());

let module = wasm_handler.wasm_module;
let module = wasm_handler.wasm_module.clone();
let mut store = wasmi::Store::new(module.engine(), user_state);
let linker = OakLinker::new(module.engine(), &mut store);
let (instance, store) = linker
.instantiate(store, module)
let linker = OakLinker::new(module.engine());
let instance = linker
.instantiate(&mut store, module)
.expect("couldn't instantiate Wasm module");

TestState {
store,
instance,
store,
wasm_handler,
request: request.clone(),
}
}
Expand Down

0 comments on commit a91172d

Please sign in to comment.