Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve function call performance #159

Merged
merged 1 commit into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 18 additions & 62 deletions src/interpreter/Interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,53 +150,10 @@ ByteCodeTable::ByteCodeTable()
b.m_opcodeInAddress = const_cast<void*>(FillByteCodeOpcodeAddress[0]);
#endif
size_t pc = reinterpret_cast<size_t>(&b);
Interpreter::interpret(dummyState, pc, nullptr, nullptr, nullptr, nullptr, nullptr);
Interpreter::interpret(dummyState, pc, nullptr, nullptr);
#endif
}

// Runs the byte code of the function that `state` is currently executing,
// using `bp` as the base of that function's stack frame, and returns whatever
// the inner interpret() overload returns.
//
// The inner interpreter is run inside a retry loop: when it throws, this
// function looks for a catch entry of the current function that covers the
// faulting program counter. If one is found, execution resumes at the catch
// position; otherwise the exception is thrown onward.
ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
uint8_t* bp)
{
DefinedFunction* df = state.currentFunction()->asDefinedFunction();
ModuleFunction* mf = df->moduleFunction();
// Execution starts at the first byte code of the function.
size_t programCounter = reinterpret_cast<size_t>(mf->byteCode());
Instance* instance = df->instance();
while (true) {
try {
return interpret(state, programCounter, bp, instance, instance->m_memories, instance->m_tables, instance->m_globals);
} catch (std::unique_ptr<Exception>& e) {
// Recover the program counter recorded for this state: scan the recorded
// (state, program counter) pairs from the back and take the first match.
for (size_t i = e->m_programCounterInfo.size(); i > 0; i--) {
if (e->m_programCounterInfo[i - 1].first == &state) {
programCounter = e->m_programCounterInfo[i - 1].second;
break;
}
}
// Only user exceptions are matched against the catch table.
if (e->isUserException()) {
bool isCatchSucessful = false;
Tag* tag = e->tag().value();
// Catch ranges are compared against the offset from the start of the byte code.
size_t offset = programCounter - reinterpret_cast<size_t>(mf->byteCode());
for (const auto& item : mf->catchInfo()) {
if (item.m_tryStart <= offset && offset < item.m_tryEnd) {
// A tag index of uint32_t max matches any tag; otherwise the instance's
// tag at that index must be the thrown tag.
if (item.m_tagIndex == std::numeric_limits<uint32_t>::max() || state.currentFunction()->asDefinedFunction()->instance()->tag(item.m_tagIndex) == tag) {
programCounter = item.m_catchStartPosition + reinterpret_cast<size_t>(mf->byteCode());
uint8_t* sp = bp + item.m_stackSizeToBe;
// For a tag-specific catch with a non-empty parameter area, copy the
// exception payload to the stack position recorded for the handler.
if (item.m_tagIndex != std::numeric_limits<uint32_t>::max() && tag->functionType()->paramStackSize()) {
memcpy(sp, e->userExceptionData().data(), tag->functionType()->paramStackSize());
}
isCatchSucessful = true;
break;
}
}
}
if (isCatchSucessful) {
// Re-enter the interpreter at the catch position.
continue;
}
}
// No matching catch entry in this function: pass the exception on.
throw std::unique_ptr<Exception>(std::move(e));
}
}
}

template <typename T>
ALWAYS_INLINE void writeValue(uint8_t* bp, ByteCodeStackOffset offset, const T& v)
{
Expand Down Expand Up @@ -488,11 +445,10 @@ static void initAddressToOpcodeTable()
ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
size_t programCounter,
uint8_t* bp,
Instance* instance,
Memory** memories,
Table** tables,
Global** globals)
Instance* instance)
{
Memory** memories = reinterpret_cast<Memory**>(reinterpret_cast<uintptr_t>(instance) + Instance::alignedSize());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't you initialize memories as instance->m_memories?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And is it enough to initialize local memories only?
Other structures like tables, globals don't need local variables since they are not accessed frequently like memories, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't you initialize memories as instance->m_memories?
-> I tested two ways of initializing memories:
. accessing instance->m_memories -> needs a memory access
. computing the address (reinterpret_cast<Memory**>(reinterpret_cast<uintptr_t>(instance) + Instance::alignedSize())) -> needs only ALU work on the CPU

The second way showed better performance when I tested.

And is it enough to initialize local memories only?
-> Yes. Only memories is used frequently when I tested;
globals and tables are not, and keeping them in local variables would need extra space on the stack.


state.m_programCounterPointer = &programCounter;

#define ADD_PROGRAM_COUNTER(codeName) programCounter += sizeof(codeName);
Expand Down Expand Up @@ -947,7 +903,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
GlobalGet32* code = (GlobalGet32*)programCounter;
ASSERT(code->index() < instance->module()->numberOfGlobalTypes());
globals[code->index()]->value().writeNBytesToMemory<4>(bp + code->dstOffset());
instance->m_globals[code->index()]->value().writeNBytesToMemory<4>(bp + code->dstOffset());
ADD_PROGRAM_COUNTER(GlobalGet32);
NEXT_INSTRUCTION();
}
Expand All @@ -957,7 +913,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
GlobalGet64* code = (GlobalGet64*)programCounter;
ASSERT(code->index() < instance->module()->numberOfGlobalTypes());
globals[code->index()]->value().writeNBytesToMemory<8>(bp + code->dstOffset());
instance->m_globals[code->index()]->value().writeNBytesToMemory<8>(bp + code->dstOffset());
ADD_PROGRAM_COUNTER(GlobalGet64);
NEXT_INSTRUCTION();
}
Expand All @@ -967,7 +923,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
GlobalGet128* code = (GlobalGet128*)programCounter;
ASSERT(code->index() < instance->module()->numberOfGlobalTypes());
globals[code->index()]->value().writeNBytesToMemory<16>(bp + code->dstOffset());
instance->m_globals[code->index()]->value().writeNBytesToMemory<16>(bp + code->dstOffset());
ADD_PROGRAM_COUNTER(GlobalGet128);
NEXT_INSTRUCTION();
}
Expand All @@ -977,7 +933,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
GlobalSet32* code = (GlobalSet32*)programCounter;
ASSERT(code->index() < instance->module()->numberOfGlobalTypes());
Value& val = globals[code->index()]->value();
Value& val = instance->m_globals[code->index()]->value();
val.readFromStack<4>(bp + code->srcOffset());
ADD_PROGRAM_COUNTER(GlobalSet32);
NEXT_INSTRUCTION();
Expand All @@ -988,7 +944,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
GlobalSet64* code = (GlobalSet64*)programCounter;
ASSERT(code->index() < instance->module()->numberOfGlobalTypes());
Value& val = globals[code->index()]->value();
Value& val = instance->m_globals[code->index()]->value();
val.readFromStack<8>(bp + code->srcOffset());
ADD_PROGRAM_COUNTER(GlobalSet64);
NEXT_INSTRUCTION();
Expand All @@ -999,7 +955,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
GlobalSet128* code = (GlobalSet128*)programCounter;
ASSERT(code->index() < instance->module()->numberOfGlobalTypes());
Value& val = globals[code->index()]->value();
Value& val = instance->m_globals[code->index()]->value();
val.readFromStack<16>(bp + code->srcOffset());
ADD_PROGRAM_COUNTER(GlobalSet128);
NEXT_INSTRUCTION();
Expand Down Expand Up @@ -1161,7 +1117,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
TableGet* code = (TableGet*)programCounter;
ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes());
Table* table = tables[code->tableIndex()];
Table* table = instance->m_tables[code->tableIndex()];
void* val = table->getElement(state, readValue<uint32_t>(bp, code->srcOffset()));
writeValue(bp, code->dstOffset(), val);

Expand All @@ -1174,7 +1130,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
TableSet* code = (TableSet*)programCounter;
ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes());
Table* table = tables[code->tableIndex()];
Table* table = instance->m_tables[code->tableIndex()];
void* ptr = readValue<void*>(bp, code->src1Offset());
table->setElement(state, readValue<uint32_t>(bp, code->src0Offset()), ptr);

Expand All @@ -1187,7 +1143,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
TableGrow* code = (TableGrow*)programCounter;
ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes());
Table* table = tables[code->tableIndex()];
Table* table = instance->m_tables[code->tableIndex()];
size_t size = table->size();

uint64_t newSize = (uint64_t)readValue<uint32_t>(bp, code->src1Offset()) + size;
Expand All @@ -1210,7 +1166,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
TableSize* code = (TableSize*)programCounter;
ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes());
Table* table = tables[code->tableIndex()];
Table* table = instance->m_tables[code->tableIndex()];
size_t size = table->size();
writeValue<uint32_t>(bp, code->dstOffset(), size);

Expand All @@ -1224,8 +1180,8 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
TableCopy* code = (TableCopy*)programCounter;
ASSERT(code->dstIndex() < instance->module()->numberOfTableTypes());
ASSERT(code->srcIndex() < instance->module()->numberOfTableTypes());
Table* dstTable = tables[code->dstIndex()];
Table* srcTable = tables[code->srcIndex()];
Table* dstTable = instance->m_tables[code->dstIndex()];
Table* srcTable = instance->m_tables[code->srcIndex()];

uint32_t dstIndex = readValue<uint32_t>(bp, code->srcOffsets()[0]);
uint32_t srcIndex = readValue<uint32_t>(bp, code->srcOffsets()[1]);
Expand All @@ -1242,7 +1198,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
TableFill* code = (TableFill*)programCounter;
ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes());
Table* table = tables[code->tableIndex()];
Table* table = instance->m_tables[code->tableIndex()];

int32_t index = readValue<int32_t>(bp, code->srcOffsets()[0]);
void* ptr = readValue<void*>(bp, code->srcOffsets()[1]);
Expand All @@ -1264,7 +1220,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
int32_t size = readValue<int32_t>(bp, code->srcOffsets()[2]);

ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes());
Table* table = tables[code->tableIndex()];
Table* table = instance->m_tables[code->tableIndex()];
table->init(state, instance, &sg, dstStart, srcStart, size);
ADD_PROGRAM_COUNTER(TableInit);
NEXT_INSTRUCTION();
Expand Down
79 changes: 71 additions & 8 deletions src/interpreter/Interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
#define __WalrusInterpreter__

#include "runtime/ExecutionState.h"
#include "runtime/Function.h"
#include "runtime/Instance.h"
#include "runtime/Module.h"
#include "runtime/Tag.h"
#include "interpreter/ByteCode.h"

namespace Walrus {
Expand All @@ -28,19 +32,78 @@ class Table;
class Global;

class Interpreter {
public:
static ByteCodeStackOffset* interpret(ExecutionState& state,
uint8_t* bp);

private:
friend class ByteCodeTable;
friend class DefinedFunction;
friend class DefinedFunctionWithTryCatch;

// Calls `function` through the interpreter on behalf of a caller whose stack
// frame starts at `bp`.
//
// `offsets` holds `parameterOffsetCount` offsets (relative to `bp`) of the
// argument slots, followed by `resultOffsetCount` offsets of the slots that
// receive the results. Arguments and results are copied as size_t-sized slots.
//
// `considerException` is a compile-time switch: when true, the interpreter is
// run inside a try/catch retry loop that dispatches user exceptions to the
// callee's catch entries; when false, interpret() is called directly and any
// exception simply propagates.
template <const bool considerException>
ALWAYS_INLINE static void callInterpreter(ExecutionState& state, DefinedFunction* function, uint8_t* bp, ByteCodeStackOffset* offsets,
uint16_t parameterOffsetCount, uint16_t resultOffsetCount)
{
ExecutionState newState(state, function);
CHECK_STACK_LIMIT(newState);

auto moduleFunction = function->moduleFunction();
// NOTE(review): ALLOCA is a project macro; presumably it reserves
// requiredStackSize() bytes for the callee frame on the native stack - confirm.
ALLOCA(uint8_t, functionStackBase, moduleFunction->requiredStackSize());

// init parameter space
// Argument i is read from the caller frame at bp + offsets[i] and written to
// the i-th size_t slot at the start of the callee frame.
for (size_t i = 0; i < parameterOffsetCount; i++) {
((size_t*)functionStackBase)[i] = *((size_t*)(bp + offsets[i]));
}

// Execution starts at the first byte code of the callee.
size_t programCounter = reinterpret_cast<size_t>(moduleFunction->byteCode());
ByteCodeStackOffset* resultOffsets;
if (considerException) {
while (true) {
try {
resultOffsets = interpret(newState, programCounter, functionStackBase, function->instance());
break;
} catch (std::unique_ptr<Exception>& e) {
// Recover the program counter recorded for the callee state: scan the
// recorded (state, program counter) pairs from the back and take the first match.
for (size_t i = e->m_programCounterInfo.size(); i > 0; i--) {
if (e->m_programCounterInfo[i - 1].first == &newState) {
programCounter = e->m_programCounterInfo[i - 1].second;
break;
}
}
// Only user exceptions are matched against the catch table.
if (e->isUserException()) {
bool isCatchSucessful = false;
Tag* tag = e->tag().value();
// Catch ranges are compared against the offset from the start of the byte code.
size_t offset = programCounter - reinterpret_cast<size_t>(moduleFunction->byteCode());
for (const auto& item : moduleFunction->catchInfo()) {
if (item.m_tryStart <= offset && offset < item.m_tryEnd) {
// A tag index of uint32_t max matches any tag; otherwise the instance's
// tag at that index must be the thrown tag.
if (item.m_tagIndex == std::numeric_limits<uint32_t>::max() || function->instance()->tag(item.m_tagIndex) == tag) {
programCounter = item.m_catchStartPosition + reinterpret_cast<size_t>(moduleFunction->byteCode());
uint8_t* sp = functionStackBase + item.m_stackSizeToBe;
// For a tag-specific catch with a non-empty parameter area, copy the
// exception payload to the stack position recorded for the handler.
if (item.m_tagIndex != std::numeric_limits<uint32_t>::max() && tag->functionType()->paramStackSize()) {
memcpy(sp, e->userExceptionData().data(), tag->functionType()->paramStackSize());
}
isCatchSucessful = true;
break;
}
}
}
if (isCatchSucessful) {
// Re-enter the interpreter at the catch position.
continue;
}
}
// No matching catch entry in this function: pass the exception on.
throw std::unique_ptr<Exception>(std::move(e));
}
}
} else {
resultOffsets = interpret(newState, programCounter, functionStackBase, function->instance());
}

// The result offsets follow the parameter offsets in `offsets`.
offsets += parameterOffsetCount;
// Result i is read from the callee frame at resultOffsets[i] and written to
// the caller frame at bp + offsets[i].
for (size_t i = 0; i < resultOffsetCount; i++) {
*((size_t*)(bp + offsets[i])) = *((size_t*)(functionStackBase + resultOffsets[i]));
}
}

static ByteCodeStackOffset* interpret(ExecutionState& state,
size_t programCounter,
uint8_t* bp,
Instance* instance,
Memory** memories,
Table** tables,
Global** globals);
Instance* instance);

static void callOperation(ExecutionState& state,
size_t& programCounter,
Expand Down
5 changes: 3 additions & 2 deletions src/parser/WASMParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,8 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
bool m_inInitExpr;
Walrus::ModuleFunction* m_currentFunction;
Walrus::FunctionType* m_currentFunctionType;
uint32_t m_initialFunctionStackSize;
uint32_t m_functionStackSizeSoFar;
uint16_t m_initialFunctionStackSize;
uint16_t m_functionStackSizeSoFar;

std::vector<VMStackInfo> m_vmStack;
std::vector<BlockInfo> m_blockInfo;
Expand Down Expand Up @@ -1879,6 +1879,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
{
BlockInfo b(BlockInfo::TryCatch, sigType, *this);
m_blockInfo.push_back(b);
m_currentFunction->m_hasTryCatch = true;
}

void processCatchExpr(Index tagIndex)
Expand Down
30 changes: 14 additions & 16 deletions src/runtime/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "runtime/Store.h"
#include "interpreter/Interpreter.h"
#include "runtime/Module.h"
#include "runtime/Tag.h"
#include "runtime/Instance.h"
#include "runtime/Value.h"

namespace Walrus {
Expand All @@ -28,7 +30,12 @@ DefinedFunction* DefinedFunction::createDefinedFunction(Store* store,
Instance* instance,
ModuleFunction* moduleFunction)
{
DefinedFunction* func = new DefinedFunction(instance, moduleFunction);
DefinedFunction* func;
if (moduleFunction->hasTryCatch()) {
func = new DefinedFunctionWithTryCatch(instance, moduleFunction);
} else {
func = new DefinedFunction(instance, moduleFunction);
}
store->appendExtern(func);
return func;
}
Expand Down Expand Up @@ -88,22 +95,13 @@ void DefinedFunction::call(ExecutionState& state, Value* argv, Value* result)
void DefinedFunction::interpreterCall(ExecutionState& state, uint8_t* bp, ByteCodeStackOffset* offsets,
uint16_t parameterOffsetCount, uint16_t resultOffsetCount)
{
ExecutionState newState(state, this);
CHECK_STACK_LIMIT(newState);

ALLOCA(uint8_t, functionStackBase, m_moduleFunction->requiredStackSize());

// init parameter space
for (size_t i = 0; i < parameterOffsetCount; i++) {
((size_t*)functionStackBase)[i] = *((size_t*)(bp + offsets[i]));
}

auto resultOffsets = Interpreter::interpret(newState, functionStackBase);
Interpreter::callInterpreter<false>(state, this, bp, offsets, parameterOffsetCount, resultOffsetCount);
}

offsets += parameterOffsetCount;
for (size_t i = 0; i < resultOffsetCount; i++) {
*((size_t*)(bp + offsets[i])) = *((size_t*)(functionStackBase + resultOffsets[i]));
}
// Interpreter-to-interpreter call for functions that contain try/catch.
// Forwards to Interpreter::callInterpreter with considerException = true, so
// the callee runs inside the exception-dispatching retry loop.
void DefinedFunctionWithTryCatch::interpreterCall(ExecutionState& state, uint8_t* bp, ByteCodeStackOffset* offsets,
uint16_t parameterOffsetCount, uint16_t resultOffsetCount)
{
Interpreter::callInterpreter<true>(state, this, bp, offsets, parameterOffsetCount, resultOffsetCount);
}

ImportedFunction* ImportedFunction::createImportedFunction(Store* store,
Expand Down
16 changes: 16 additions & 0 deletions src/runtime/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,22 @@ class DefinedFunction : public Function {
ModuleFunction* m_moduleFunction;
};

// DefinedFunction variant used for module functions that contain try/catch:
// DefinedFunction::createDefinedFunction instantiates this class when
// moduleFunction->hasTryCatch() is true. It differs from DefinedFunction only
// in its interpreterCall override.
class DefinedFunctionWithTryCatch : public DefinedFunction {
friend class DefinedFunction;
friend class Module;

public:
// Overrides the interpreter call path of DefinedFunction.
virtual void interpreterCall(ExecutionState& state, uint8_t* bp, ByteCodeStackOffset* offsets,
uint16_t parameterOffsetCount, uint16_t resultOffsetCount) override;

protected:
// Protected: instances are created by the friend classes declared above,
// not by arbitrary callers.
DefinedFunctionWithTryCatch(Instance* instance,
ModuleFunction* moduleFunction)
: DefinedFunction(instance, moduleFunction)
{
}
};

class ImportedFunction : public Function {
public:
typedef std::function<void(ExecutionState& state, Value* argv, Value* result, void* data)> ImportedFunctionCallback;
Expand Down
5 changes: 5 additions & 0 deletions src/runtime/Instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ void Instance::freeInstance(Instance* instance)

// Creates an instance of `module`. The memories, globals, tables, functions
// and tags pointers all start out null, and the new instance is registered
// with the module's store.
Instance::Instance(Module* module)
: m_module(module)
, m_memories(nullptr)
, m_globals(nullptr)
, m_tables(nullptr)
, m_functions(nullptr)
, m_tags(nullptr)
{
module->store()->appendInstance(this);
}
Expand Down
Loading