Skip to content

Commit

Permalink
Implement br_table and return
Browse files Browse the repository at this point in the history
Signed-off-by: Seonghyun Kim <[email protected]>
  • Loading branch information
ksh8281 authored and clover2123 committed Oct 31, 2023
1 parent 0c14ef0 commit 78d2707
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 19 deletions.
42 changes: 42 additions & 0 deletions src/interpreter/ByteCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,48 @@ class Select : public ByteCode {
uint32_t m_size;
};

class BrTable : public ByteCode {
public:
BrTable(uint32_t m_tableSize)
: ByteCode(OpcodeKind::BrTableOpcode)
, m_defaultOffset(0)
, m_tableSize(m_tableSize)
{
}

int32_t defaultOffset() const { return m_defaultOffset; }
void setDefaultOffset(int32_t offset)
{
m_defaultOffset = offset;
}

uint32_t tableSize() const { return m_tableSize; }
int32_t* jumpOffsets() const
{
return reinterpret_cast<int32_t*>(reinterpret_cast<size_t>(this) + sizeof(BrTable));
}

#if !defined(NDEBUG)
virtual void dump(size_t pos)
{
printf("tableSize: %" PRIu32 ", defaultOffset: %" PRId32, m_tableSize, m_defaultOffset);
printf(" table contents: ");
for (size_t i = 0; i < m_tableSize; i++) {
printf("%zu->%" PRId32 " ", i, jumpOffsets()[i]);
}
}

virtual size_t byteCodeSize()
{
return sizeof(BrTable) + sizeof(int32_t) * m_tableSize;
}
#endif

protected:
int32_t m_defaultOffset;
uint32_t m_tableSize;
};

class MemorySize : public ByteCode {
public:
MemorySize(uint32_t index)
Expand Down
17 changes: 16 additions & 1 deletion src/interpreter/Interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ void Interpreter::interpret(ExecutionState& state,
{
JumpIfFalse* code = (JumpIfFalse*)programCounter;
if (readValue<int32_t>(sp)) {
ADD_PROGRAM_COUNTER(JumpIfTrue);
ADD_PROGRAM_COUNTER(JumpIfFalse);
} else {
programCounter += code->offset();
}
Expand All @@ -506,6 +506,21 @@ void Interpreter::interpret(ExecutionState& state,
NEXT_INSTRUCTION();
}

DEFINE_OPCODE(BrTable)
:
{
BrTable* code = (BrTable*)programCounter;
uint32_t value = readValue<uint32_t>(sp);

if (value >= code->tableSize()) {
// default case
programCounter += code->defaultOffset();
} else {
programCounter += code->jumpOffsets()[value];
}
NEXT_INSTRUCTION();
}

DEFINE_OPCODE(MemorySize)
:
{
Expand Down
110 changes: 94 additions & 16 deletions src/parser/WASMParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
BlockType m_blockType;
Type m_returnValueType;
size_t m_position;
size_t m_stackPushCount;

static_assert(sizeof(Walrus::JumpIfTrue) == sizeof(Walrus::JumpIfFalse), "");
struct JumpToEndBrInfo {
Expand All @@ -82,6 +83,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
: m_blockType(type)
, m_returnValueType(returnValueType)
, m_position(0)
, m_stackPushCount(0)
{
}
};
Expand Down Expand Up @@ -143,22 +145,22 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
new Walrus::String(fieldName), funcIndex, sigIndex));
}

virtual void OnExportCount(Index count)
virtual void OnExportCount(Index count) override
{
m_module->m_export.reserve(count);
}

virtual void OnExport(int kind, Index exportIndex, std::string name, Index itemIndex)
virtual void OnExport(int kind, Index exportIndex, std::string name, Index itemIndex) override
{
m_module->m_export.pushBack(new Walrus::ModuleExport(static_cast<Walrus::ModuleExport::Type>(kind), new Walrus::String(name), exportIndex, itemIndex));
}

virtual void OnMemoryCount(Index count)
virtual void OnMemoryCount(Index count) override
{
m_module->m_memory.reserve(count);
}

virtual void OnMemory(Index index, size_t initialSize, size_t maximumSize)
virtual void OnMemory(Index index, size_t initialSize, size_t maximumSize) override
{
ASSERT(index == m_module->m_memory.size());
m_module->m_memory.pushBack(std::make_pair(initialSize, maximumSize));
Expand All @@ -185,6 +187,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
virtual void BeginFunctionBody(Index index, Offset size) override
{
ASSERT(m_currentFunction == nullptr);
m_shouldContinueToGenerateByteCode = true;
m_currentFunction = m_module->function(index);
m_currentFunctionType = m_module->functionType(m_currentFunction->functionTypeIndex());
m_functionStackSizeSoFar = m_currentFunctionType->paramStackSize();
Expand Down Expand Up @@ -340,15 +343,16 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {

virtual void OnIfExpr(Type sigType) override
{
ASSERT(peekVMStack() == Walrus::valueSizeInStack(toValueKindForLocalType(Type::I32)));
popVMStack();

BlockInfo b(BlockInfo::IfElse, sigType);
b.m_position = m_currentFunction->currentByteCodeSize();
b.m_jumpToEndBrInfo.push_back({ true, b.m_position });
b.m_stackPushCount = m_vmStack.size();
m_blockInfo.push_back(b);
m_currentFunction->pushByteCode(Walrus::JumpIfFalse());

ASSERT(peekVMStack() == Walrus::valueSizeInStack(toValueKindForLocalType(Type::I32)));
popVMStack();

if (sigType != Type::Void) {
pushVMStack(Walrus::valueSizeInStack(toValueKindForLocalType(sigType)));
}
Expand All @@ -373,6 +377,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
{
BlockInfo b(BlockInfo::Loop, sigType);
b.m_position = m_currentFunction->currentByteCodeSize();
b.m_stackPushCount = m_vmStack.size();
m_blockInfo.push_back(b);

if (sigType != Type::Void) {
Expand All @@ -384,6 +389,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
{
BlockInfo b(BlockInfo::Block, sigType);
b.m_position = m_currentFunction->currentByteCodeSize();
b.m_stackPushCount = m_vmStack.size();
m_blockInfo.push_back(b);

if (sigType != Type::Void) {
Expand All @@ -405,20 +411,46 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
size_t dropStackValuesBeforeBrIfNeeds(Index depth)
{
size_t dropValueSize = 0;
auto iter = m_blockInfo.rbegin();
while (depth) {
if (depth < m_blockInfo.size()) {
auto iter = m_blockInfo.rbegin() + depth;

size_t start = iter->m_stackPushCount;
for (size_t i = start; i < m_vmStack.size(); i++) {
dropValueSize += m_vmStack[i];
}

if (iter->m_returnValueType != Type::Void) {
dropValueSize += Walrus::valueSizeInStack(toValueKindForLocalType(iter->m_returnValueType));
dropValueSize -= Walrus::valueSizeInStack(toValueKindForLocalType(iter->m_returnValueType));
}
} else if (m_blockInfo.size()) {
auto iter = m_blockInfo.begin();
size_t start = iter->m_stackPushCount;
for (size_t i = start; i < m_vmStack.size(); i++) {
dropValueSize += m_vmStack[i];
}
iter++;
depth--;
}

return dropValueSize;
}

virtual void OnBrExpr(Index depth)
virtual void OnBrExpr(Index depth) override
{
if (m_blockInfo.size() == depth) {
// this case acts like return
for (size_t i = 0; i < m_currentFunctionType->result().size(); i++) {
ASSERT(*(m_vmStack.rbegin() + i) == Walrus::valueSizeInStack(m_currentFunctionType->result()[m_currentFunctionType->result().size() - i - 1]));
}
m_currentFunction->pushByteCode(Walrus::End());
auto dropSize = dropStackValuesBeforeBrIfNeeds(depth);
while (dropSize) {
dropSize -= popVMStack();
}

if (!m_blockInfo.size()) {
// stop to generate bytecode from here!
m_shouldContinueToGenerateByteCode = false;
}
return;
}
auto& blockInfo = findBlockInfoInBr(depth);
auto offset = (int32_t)blockInfo.m_position - (int32_t)m_currentFunction->currentByteCodeSize();
auto dropSize = dropStackValuesBeforeBrIfNeeds(depth);
Expand All @@ -431,8 +463,24 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
m_currentFunction->pushByteCode(Walrus::Jump(offset));
}

virtual void OnBrIfExpr(Index depth)
virtual void OnBrIfExpr(Index depth) override
{
if (m_blockInfo.size() == depth) {
// this case acts like return
ASSERT(peekVMStack() == Walrus::valueSizeInStack(toValueKindForLocalType(Type::I32)));
size_t pos = m_currentFunction->currentByteCodeSize();
m_currentFunction->pushByteCode(Walrus::JumpIfFalse(sizeof(Walrus::JumpIfFalse) + sizeof(Walrus::End)));
m_currentFunction->pushByteCode(Walrus::End());
for (size_t i = 0; i < m_currentFunctionType->result().size(); i++) {
ASSERT(*(m_vmStack.rbegin() + i) == Walrus::valueSizeInStack(m_currentFunctionType->result()[m_currentFunctionType->result().size() - i - 1]));
}
popVMStack();
return;
}

ASSERT(peekVMStack() == Walrus::valueSizeInStack(toValueKindForLocalType(Type::I32)));
popVMStack();

auto& blockInfo = findBlockInfoInBr(depth);
auto dropSize = dropStackValuesBeforeBrIfNeeds(depth);
if (dropSize) {
Expand All @@ -453,12 +501,37 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
}
m_currentFunction->pushByteCode(Walrus::JumpIfTrue(offset));
}
}

virtual void OnBrTableExpr(Index numTargets, Index* targetDepths, Index defaultTargetDepth) override
{
ASSERT(peekVMStack() == Walrus::valueSizeInStack(toValueKindForLocalType(Type::I32)));
popVMStack();

size_t brTableCode = m_currentFunction->currentByteCodeSize();
m_currentFunction->pushByteCode(Walrus::BrTable(numTargets));

if (numTargets) {
m_currentFunction->expandByteCode(sizeof(int32_t) * numTargets);
std::vector<size_t> offsets;

for (Index i = 0; i < numTargets; i++) {
offsets.push_back(m_currentFunction->currentByteCodeSize() - brTableCode);
OnBrExpr(targetDepths[i]);
}

for (Index i = 0; i < numTargets; i++) {
m_currentFunction->peekByteCode<Walrus::BrTable>(brTableCode)->jumpOffsets()[i] = offsets[i];
}
}

// generate default
size_t pos = m_currentFunction->currentByteCodeSize();
OnBrExpr(defaultTargetDepth);
m_currentFunction->peekByteCode<Walrus::BrTable>(brTableCode)->setDefaultOffset(pos - brTableCode);
}

virtual void OnSelectExpr(Index resultCount, Type* resultTypes)
virtual void OnSelectExpr(Index resultCount, Type* resultTypes) override
{
// TODO implement selectT
ASSERT(resultCount == 0);
Expand Down Expand Up @@ -491,6 +564,11 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
{
}

virtual void OnReturnExpr() override
{
OnBrExpr(m_blockInfo.size());
}

virtual void OnEndExpr() override
{
if (m_blockInfo.size()) {
Expand Down
5 changes: 5 additions & 0 deletions src/runtime/Module.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ class ModuleFunction : public gc {
return reinterpret_cast<CodeType*>(&m_byteCode[position]);
}

void expandByteCode(size_t s)
{
m_byteCode.resizeWithUninitializedValues(m_byteCode.size() + s);
}

size_t currentByteCodeSize() const
{
return m_byteCode.size();
Expand Down
44 changes: 44 additions & 0 deletions test/basic/br.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
(module
(memory 7)
(func (export "br0")(result i32 i32 i32)
i32.const 1
i32.const 2
i32.const 3
br 0
i32.const 100 ;; the bytecode is not generated from here to function end
drop
drop
drop
memory.grow
)
(func (export "br_if0")(param i32)(result i32 i32 i32)
i32.const 1
i32.const 2
i32.const 3
local.get 0
br_if 0
drop
i32.const 100
memory.grow
)
(func (export "check")(result i32)
memory.size
)
(func (export "br0_1")(param i32)(result i32)
local.get 0
(if (then
(i32.const 100)
(br 1)
))
i32.const 200
)
)

(assert_return (invoke "br0") (i32.const 1)(i32.const 2)(i32.const 3))
(assert_return (invoke "br_if0"(i32.const 1)) (i32.const 1)(i32.const 2)(i32.const 3))
(assert_return (invoke "check") (i32.const 7))
(assert_return (invoke "br_if0"(i32.const 0)) (i32.const 1)(i32.const 2)(i32.const 7))
(assert_return (invoke "check") (i32.const 107))
(assert_return (invoke "br0_1"(i32.const 1))(i32.const 100))
(assert_return (invoke "br0_1"(i32.const 0))(i32.const 200))

Loading

0 comments on commit 78d2707

Please sign in to comment.