// -----------------------------------------------------------------------------
// Reviewer header (free text placed before the first `diff --git` line).
//
// This file is a unified git diff, not compilable C++. In this copy the line
// breaks inside hunks are collapsed and text between angle brackets seems to
// have been dropped (e.g. `std::numeric_limits::max()`, `std::vector m_vmStack`)
// -- presumably template arguments such as `<size_t>`; confirm against the
// repository before applying.
//
// What the patch does, as far as the hunks below show:
//  * src/parser/WASMParser.cpp
//      - Removes VMStackInfo's per-local reference counting
//        (increaseRefCountIfNeeds / decreaseRefCountIfNeeds), the
//        LocalInfo::m_refCount and m_canUseDirectReference fields,
//        BlockInfo::m_parameterPositions, m_lastPushedOpcode / m_lastOpcode,
//        omitUpdateLocalValueIfPossible() and resetFunctionCodeDataFromHere().
//      - Adds LocalVariableUsage records (local index, start/end reader offset,
//        push count, write flag) that are appended in pushVMStack() and closed
//        when the value is popped or handed to a block, only while
//        m_inPreprocess is true.
//      - OnLocalSetExpr / OnLocalTeeExpr mark overlapping usages as written via
//        updateWriteUsageOfLocalIfNeeds(); OnLocalGetExpr pushes a direct
//        reference to the local unless a usage covering the current reader
//        offset has m_hasWriteUsage set.
//      - OnEndPreprocess() clears generated byte code, block/catch info and the
//        VM stack and resets the stack size so the function body is read again.
//  * third_party/wabt (binary-reader.h / binary-reader.cc)
//      - Adds NeedsPreprocess / OnStartPreprocess / OnEndPreprocess to
//        BinaryReaderDelegate; ReadInstructions() reads the instruction stream
//        once in preprocess mode, rewinds state_.offset to start_offset, then
//        reads it again.
//  * third_party/wabt/walrus (binary-reader-walrus.h / .cc)
//      - Forwards the two new callbacks to the external delegate and returns
//        true from NeedsPreprocess().
//
// NOTE(review): `size_t pos = *...m_readerOffsetPointer;` is declared in the
//   BlockInfo constructor loop and in the pop hunk (@@ -411,6 +416,23 @@) but is
//   never read in the lines shown; the code dereferences the pointer again.
// NOTE(review): in OnLocalGetExpr the range-for variable `r` shadows the outer
//   `r` from resolveLocalOffsetAndSize(), and `auto pos` in the else branch
//   shadows the outer `size_t pos`.
// NOTE(review): the `while (true)` reverse searches stop at rend() only through
//   ASSERT. If ASSERT is compiled out in release builds (not visible here --
//   confirm), a missing entry would advance the iterator past rend().
// NOTE(review): findNearestUsage() has no caller in this diff, and its trailing
//   ASSERT_NOT_REACHED() follows a loop with no break -- confirm it is needed.
// NOTE(review): binary-reader-walrus.h adds two pure virtual methods; any other
//   WASMBinaryReaderDelegate implementation would need overrides -- confirm.
// NOTE(review): OnLocalGetExpr and updateWriteUsageOfLocalIfNeeds() scan the
//   whole m_localVariableUsage vector on every call; pushVMStack() scans
//   m_vmStack for each local push during preprocess.
// -----------------------------------------------------------------------------
diff --git a/src/parser/WASMParser.cpp b/src/parser/WASMParser.cpp index ac4d44c90..811f1bfed 100644 --- a/src/parser/WASMParser.cpp +++ b/src/parser/WASMParser.cpp @@ -160,7 +160,6 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { , m_nonOptimizedPosition(nonOptimizedPosition) , m_localIndex(localIndex) { - increaseRefCountIfNeeds(); } VMStackInfo(const VMStackInfo& src) @@ -170,24 +169,19 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { , m_nonOptimizedPosition(src.m_nonOptimizedPosition) , m_localIndex(src.m_localIndex) { - increaseRefCountIfNeeds(); } ~VMStackInfo() { - decreaseRefCountIfNeeds(); } const VMStackInfo& operator=(const VMStackInfo& src) { - decreaseRefCountIfNeeds(); - m_reader = src.m_reader; m_valueType = src.m_valueType; m_position = src.m_position; m_nonOptimizedPosition = src.m_nonOptimizedPosition; m_localIndex = src.m_localIndex; - increaseRefCountIfNeeds(); return *this; } @@ -196,6 +190,11 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { return m_localIndex != std::numeric_limits::max(); } + void clearLocalIndex() + { + m_localIndex = std::numeric_limits::max(); + } + size_t position() const { return m_position; @@ -227,20 +226,6 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { } private: - void increaseRefCountIfNeeds() - { - if (m_localIndex != std::numeric_limits::max()) { - m_reader.m_localInfo[m_localIndex].m_refCount++; - } - } - - void decreaseRefCountIfNeeds() - { - if (m_localIndex != std::numeric_limits::max()) { - m_reader.m_localInfo[m_localIndex].m_refCount--; - } - } - WASMBinaryReader& m_reader; Walrus::Value::Type m_valueType; size_t m_position; // effective position (local values will have different position) @@ -259,7 +244,6 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { Type m_returnValueType; size_t m_position; std::vector m_vmStack; - std::vector m_parameterPositions; uint32_t m_functionStackSizeSoFar; bool 
m_shouldRestoreVMStackAtEnd; bool m_byteCodeGenerationStopped; @@ -282,7 +266,6 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { : m_blockType(type) , m_returnValueType(returnValueType) , m_position(0) - , m_vmStack(binaryReader.m_vmStack) , m_functionStackSizeSoFar(binaryReader.m_functionStackSizeSoFar) , m_shouldRestoreVMStackAtEnd(false) , m_byteCodeGenerationStopped(false) @@ -293,22 +276,29 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { auto endIter = binaryReader.m_vmStack.rbegin() + param.size(); auto iter = binaryReader.m_vmStack.rbegin(); while (iter != endIter) { - m_parameterPositions.push_back(iter->nonOptimizedPosition()); - iter++; - } - - // assign local values which use direct-access into general register - endIter = binaryReader.m_vmStack.rbegin() + param.size(); - iter = binaryReader.m_vmStack.rbegin(); - while (iter != endIter) { - if (iter->position() != iter->nonOptimizedPosition()) { + if (iter->hasValidLocalIndex()) { binaryReader.generateMoveCodeIfNeeds(iter->position(), iter->nonOptimizedPosition(), iter->valueType()); iter->setPosition(iter->nonOptimizedPosition()); + if (binaryReader.m_inPreprocess) { + size_t pos = *binaryReader.m_readerOffsetPointer; + auto localVariableIter = binaryReader.m_localVariableUsage.rbegin(); + while (true) { + ASSERT(localVariableIter != binaryReader.m_localVariableUsage.rend()); + if (localVariableIter->m_localIndex == iter->localIndex() + && localVariableIter->m_endPosition == std::numeric_limits::max()) { + localVariableIter->m_endPosition = *binaryReader.m_readerOffsetPointer; + break; + } + localVariableIter++; + } + } + iter->clearLocalIndex(); } iter++; } } + m_vmStack = binaryReader.m_vmStack; m_position = binaryReader.m_currentFunction->currentByteCodeSize(); } }; @@ -316,13 +306,47 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { size_t* m_readerOffsetPointer; size_t m_codeStartOffset; + struct LocalVariableUsage { + size_t m_localIndex; + 
size_t m_startPosition; + size_t m_endPosition; + size_t m_pushCount; + bool m_hasWriteUsage; + + LocalVariableUsage(size_t localIndex, size_t startPosition, size_t pushCount) + : m_localIndex(localIndex) + , m_startPosition(startPosition) + , m_endPosition(std::numeric_limits::max()) + , m_pushCount(pushCount) + , m_hasWriteUsage(false) + { + } + }; + + LocalVariableUsage& findNearestUsage(size_t localIndex) + { + size_t pos = *m_readerOffsetPointer; + auto iter = m_localVariableUsage.rbegin(); + while (true) { + ASSERT(iter != m_localVariableUsage.rend()); + if (iter->m_localIndex == localIndex) { + if (iter->m_startPosition <= pos) { + return *iter; + } + } + iter++; + } + ASSERT_NOT_REACHED(); + } + + bool m_inPreprocess; + std::vector m_localVariableUsage; + Walrus::ModuleFunction* m_currentFunction; Walrus::FunctionType* m_currentFunctionType; uint32_t m_initialFunctionStackSize; uint32_t m_functionStackSizeSoFar; uint32_t m_lastByteCodePosition; - WASMOpcode m_lastPushedOpcode; - uint32_t m_lastOpcode[2]; std::vector m_vmStack; std::vector m_blockInfo; @@ -336,19 +360,10 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { std::vector m_catchInfo; struct LocalInfo { Walrus::Value::Type m_valueType; - size_t m_refCount; - bool m_canUseDirectReference; LocalInfo(Walrus::Value::Type type) : m_valueType(type) - , m_refCount(0) - , m_canUseDirectReference(true) { } - - void reset() - { - m_refCount = 0; - } }; std::vector m_localInfo; @@ -373,30 +388,20 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { return pos; } - void resetFunctionCodeDataFromHere() + void pushVMStack(Walrus::Value::Type type, size_t pos, size_t localIndex = std::numeric_limits::max()) { - m_skipValidationUntil = *m_readerOffsetPointer; - *m_readerOffsetPointer = m_codeStartOffset; - - m_currentFunction->m_byteCode.clear(); - m_currentFunction->m_catchInfo.clear(); - m_blockInfo.clear(); - m_catchInfo.clear(); - - m_functionStackSizeSoFar = 
m_initialFunctionStackSize; - m_lastByteCodePosition = 0; - m_lastPushedOpcode = WASMOpcode::OpcodeKindEnd; - m_lastOpcode[0] = m_lastOpcode[1] = 0; - - m_vmStack.clear(); - - for (auto& info : m_localInfo) { - info.reset(); + if (m_inPreprocess) { + if (localIndex != std::numeric_limits::max()) { + size_t pushCount = 0; + for (const auto& stack : m_vmStack) { + if (stack.localIndex() == localIndex) { + pushCount++; + } + } + m_localVariableUsage.push_back(LocalVariableUsage(localIndex, *m_readerOffsetPointer, pushCount)); + } } - } - void pushVMStack(Walrus::Value::Type type, size_t pos, size_t localIndex = std::numeric_limits::max()) - { m_vmStack.push_back(VMStackInfo(*this, type, pos, m_functionStackSizeSoFar, localIndex)); m_functionStackSizeSoFar += Walrus::valueStackAllocatedSize(type); if (UNLIKELY(m_functionStackSizeSoFar > std::numeric_limits::max())) { @@ -411,6 +416,23 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { auto info = m_vmStack.back(); m_functionStackSizeSoFar -= Walrus::valueStackAllocatedSize(info.valueType()); m_vmStack.pop_back(); + + if (m_inPreprocess) { + if (info.hasValidLocalIndex()) { + size_t pos = *m_readerOffsetPointer; + auto iter = m_localVariableUsage.rbegin(); + while (true) { + ASSERT(iter != m_localVariableUsage.rend()); + if (iter->m_localIndex == info.localIndex() + && iter->m_endPosition == std::numeric_limits::max()) { + iter->m_endPosition = *m_readerOffsetPointer; + break; + } + iter++; + } + } + } + return info; } @@ -445,7 +467,6 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { } m_initialFunctionStackSize = m_functionStackSizeSoFar = m_currentFunctionType->paramStackSize(); m_lastByteCodePosition = 0; - m_lastPushedOpcode = WASMOpcode::OpcodeKindEnd; m_currentFunction->m_requiredStackSize = std::max( m_currentFunction->m_requiredStackSize, m_functionStackSizeSoFar); } @@ -462,33 +483,19 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { void pushByteCode(const 
CodeType& code, WASMOpcode opcode) { m_lastByteCodePosition = m_currentFunction->currentByteCodeSize(); - m_lastPushedOpcode = opcode; m_currentFunction->pushByteCode(code); } - bool canUseDirectReference(uint32_t localIndex, uint32_t pos) - { - for (const auto& bi : m_blockInfo) { - for (uint32_t p : bi.m_parameterPositions) { - if (pos == p) { - return false; - } - } - } - return m_localInfo[localIndex].m_canUseDirectReference; - } - public: WASMBinaryReader() : m_readerOffsetPointer(nullptr) , m_codeStartOffset(0) + , m_inPreprocess(false) , m_currentFunction(nullptr) , m_currentFunctionType(nullptr) , m_initialFunctionStackSize(0) , m_functionStackSizeSoFar(0) , m_lastByteCodePosition(0) - , m_lastPushedOpcode(WASMOpcode::OpcodeKindEnd) - , m_lastOpcode{ 0, 0 } , m_elementTableIndex(0) , m_segmentMode(Walrus::SegmentMode::None) { @@ -822,10 +829,30 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { m_codeStartOffset = *m_readerOffsetPointer; } + virtual void OnStartPreprocess() override + { + m_inPreprocess = true; + m_localVariableUsage.clear(); + } + + virtual void OnEndPreprocess() override + { + m_inPreprocess = false; + m_skipValidationUntil = *m_readerOffsetPointer - 1; + m_shouldContinueToGenerateByteCode = true; + + m_currentFunction->m_byteCode.clear(); + m_currentFunction->m_catchInfo.clear(); + m_blockInfo.clear(); + m_catchInfo.clear(); + + m_functionStackSizeSoFar = m_initialFunctionStackSize; + m_lastByteCodePosition = 0; + m_vmStack.clear(); + } + virtual void OnOpcode(uint32_t opcode) override { - m_lastOpcode[1] = m_lastOpcode[0]; - m_lastOpcode[0] = opcode; } virtual void OnCallExpr(uint32_t index) override @@ -950,7 +977,19 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { { auto r = resolveLocalOffsetAndSize(localIndex); auto localValueType = m_localInfo[localIndex].m_valueType; - if (canUseDirectReference(localIndex, m_functionStackSizeSoFar)) { + + bool canUseDirectReference = true; + size_t pos = 
*m_readerOffsetPointer; + for (const auto& r : m_localVariableUsage) { + if (r.m_localIndex == localIndex && r.m_startPosition <= pos && pos <= r.m_endPosition) { + if (r.m_hasWriteUsage) { + canUseDirectReference = false; + break; + } + } + } + + if (canUseDirectReference) { pushVMStack(localValueType, r.first, localIndex); } else { auto pos = m_functionStackSizeSoFar; @@ -959,66 +998,39 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { } } - bool omitUpdateLocalValueIfPossible(Index localIndex, std::pair localOffsetAndSize, const VMStackInfo& stack) + void updateWriteUsageOfLocalIfNeeds(Index localIndex) { - if (canUseDirectReference(localIndex, stack.position()) && stack.position() != localOffsetAndSize.first && !stack.hasValidLocalIndex()) { - // we should check last opcode and bytecode are same - // because some opcode omitted by optimization - // eg) (i32.add) (local.get 0) ;; local.get 0 can be omitted by direct access - if (m_lastOpcode[1] == static_cast(m_lastPushedOpcode) && isBinaryOperation(m_lastPushedOpcode)) { - m_currentFunction->peekByteCode(m_lastByteCodePosition)->setDstOffset(localOffsetAndSize.first); - } else if (m_lastPushedOpcode == WASMOpcode::Const32Opcode) { - m_currentFunction->peekByteCode(m_lastByteCodePosition)->setDstOffset(localOffsetAndSize.first); - } else if (m_lastPushedOpcode == WASMOpcode::Const64Opcode) { - m_currentFunction->peekByteCode(m_lastByteCodePosition)->setDstOffset(localOffsetAndSize.first); - } else { - return false; + if (m_inPreprocess) { + size_t pos = *m_readerOffsetPointer; + auto iter = m_localVariableUsage.begin(); + while (iter != m_localVariableUsage.end()) { + if (localIndex == iter->m_localIndex + && iter->m_startPosition <= pos && pos <= iter->m_endPosition) { + iter->m_hasWriteUsage = true; + } + iter++; } - return true; } - return false; } virtual void OnLocalSetExpr(Index localIndex) override { auto r = resolveLocalOffsetAndSize(localIndex); - if (m_localInfo[localIndex].m_refCount 
&& m_localInfo[localIndex].m_canUseDirectReference) { - // src and dst are same! - // example) (local.get 0) (local.set 0) ;; w/direct access - if (peekVMStack() != r.first) { - // rewind generating bytecode - m_localInfo[localIndex].m_canUseDirectReference = false; - resetFunctionCodeDataFromHere(); - return; - } - } ASSERT(m_localInfo[localIndex].m_valueType == peekVMStackValueType()); auto src = popVMStackInfo(); - if (!omitUpdateLocalValueIfPossible(localIndex, r, src)) { - generateMoveCodeIfNeeds(src.position(), r.first, src.valueType()); - } + generateMoveCodeIfNeeds(src.position(), r.first, src.valueType()); + updateWriteUsageOfLocalIfNeeds(localIndex); } virtual void OnLocalTeeExpr(Index localIndex) override { - if (m_localInfo[localIndex].m_refCount && m_localInfo[localIndex].m_canUseDirectReference) { - m_localInfo[localIndex].m_canUseDirectReference = false; - resetFunctionCodeDataFromHere(); - return; - } - auto valueType = m_localInfo[localIndex].m_valueType; auto r = resolveLocalOffsetAndSize(localIndex); ASSERT(valueType == peekVMStackValueType()); auto dstInfo = peekVMStackInfo(); - - if (omitUpdateLocalValueIfPossible(localIndex, r, dstInfo)) { - auto oldInfo = popVMStackInfo(); - pushVMStack(oldInfo.valueType(), r.first, localIndex); - } else { - generateMoveCodeIfNeeds(dstInfo.position(), r.first, valueType); - } + generateMoveCodeIfNeeds(dstInfo.position(), r.first, valueType); + updateWriteUsageOfLocalIfNeeds(localIndex); } virtual void OnGlobalGetExpr(Index index) override @@ -1121,6 +1133,13 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { void restoreVMStackBy(const BlockInfo& blockInfo) { + if (blockInfo.m_vmStack.size() <= m_vmStack.size()) { + size_t diff = m_vmStack.size() - blockInfo.m_vmStack.size(); + for (size_t i = 0; i < diff; i++) { + popVMStack(); + } + ASSERT(blockInfo.m_vmStack.size() == m_vmStack.size()); + } m_vmStack = blockInfo.m_vmStack; m_functionStackSizeSoFar = blockInfo.m_functionStackSizeSoFar; } @@ 
-1209,8 +1228,13 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { if (m_blockInfo.size()) { m_resumeGenerateByteCodeAfterNBlockEnd = 1; - m_blockInfo.back().m_shouldRestoreVMStackAtEnd = true; - m_blockInfo.back().m_byteCodeGenerationStopped = true; + auto& blockInfo = m_blockInfo.back(); + blockInfo.m_shouldRestoreVMStackAtEnd = true; + blockInfo.m_byteCodeGenerationStopped = true; + } else { + while (m_vmStack.size()) { + popVMStack(); + } } m_shouldContinueToGenerateByteCode = false; } @@ -1316,7 +1340,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { } } - void generateEndCode() + void generateEndCode(bool shouldClearVMStack = false) { if (UNLIKELY(m_currentFunctionType->result().size() > m_vmStack.size())) { // error case of global init expr @@ -1331,6 +1355,12 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { for (size_t i = 0; i < result.size(); i++) { end->resultOffsets()[result.size() - i - 1] = (m_vmStack.rbegin() + i)->position(); } + + if (shouldClearVMStack) { + for (size_t i = 0; i < result.size(); i++) { + popVMStack(); + } + } } void generateFunctionReturnCode(bool shouldClearVMStack = false) @@ -1837,7 +1867,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { } } } else { - generateEndCode(); + generateEndCode(true); } } diff --git a/third_party/wabt/include/wabt/binary-reader.h b/third_party/wabt/include/wabt/binary-reader.h index 4c08c3ce2..2cef054f8 100644 --- a/third_party/wabt/include/wabt/binary-reader.h +++ b/third_party/wabt/include/wabt/binary-reader.h @@ -190,6 +190,10 @@ class BinaryReaderDelegate { virtual Result OnLocalDecl(Index decl_index, Index count, Type type) = 0; virtual Result OnStartReadInstructions() { return Result::Ok; } + virtual bool NeedsPreprocess() { return false; } + virtual Result OnStartPreprocess() { return Result::Ok; } + virtual Result OnEndPreprocess() { return Result::Ok; } + /* Function expressions; called between BeginFunctionBody and 
EndFunctionBody */ virtual Result OnOpcode(Opcode Opcode) = 0; diff --git a/third_party/wabt/include/wabt/walrus/binary-reader-walrus.h b/third_party/wabt/include/wabt/walrus/binary-reader-walrus.h index c1f377025..24838c99b 100644 --- a/third_party/wabt/include/wabt/walrus/binary-reader-walrus.h +++ b/third_party/wabt/include/wabt/walrus/binary-reader-walrus.h @@ -98,6 +98,8 @@ class WASMBinaryReaderDelegate { virtual void OnLocalDecl(Index decl_index, Index count, Type type) = 0; virtual void OnStartReadInstructions() = 0; + virtual void OnStartPreprocess() = 0; + virtual void OnEndPreprocess() = 0; virtual void OnOpcode(uint32_t opcode) = 0; diff --git a/third_party/wabt/src/binary-reader.cc b/third_party/wabt/src/binary-reader.cc index 390ce45a9..0f2f58ab3 100644 --- a/third_party/wabt/src/binary-reader.cc +++ b/third_party/wabt/src/binary-reader.cc @@ -678,7 +678,22 @@ Result BinaryReader::ReadInstructions(bool stop_on_end, Opcode* final_opcode) { CALLBACK(OnStartReadInstructions); - while (state_.offset < end_offset) { + auto start_offset = state_.offset; + bool in_preprocess = delegate_->NeedsPreprocess(); + if (in_preprocess) { + CALLBACK(OnStartPreprocess); + } + + while (true) { + if (state_.offset >= end_offset) { + if (in_preprocess) { + CALLBACK(OnEndPreprocess); + in_preprocess = false; + state_.offset = start_offset; + } else { + break; + } + } Opcode opcode; CHECK_RESULT(ReadOpcode(&opcode, "opcode")); CALLBACK(OnOpcode, opcode); @@ -814,7 +829,13 @@ Result BinaryReader::ReadInstructions(bool stop_on_end, case Opcode::End: CALLBACK0(OnEndExpr); if (stop_on_end) { - return Result::Ok; + if (in_preprocess) { + CALLBACK(OnEndPreprocess); + in_preprocess = false; + state_.offset = start_offset; + } else { + return Result::Ok; + } } break; diff --git a/third_party/wabt/src/walrus/binary-reader-walrus.cc b/third_party/wabt/src/walrus/binary-reader-walrus.cc index 7e4362b1e..005b77e38 100644 --- a/third_party/wabt/src/walrus/binary-reader-walrus.cc +++ 
b/third_party/wabt/src/walrus/binary-reader-walrus.cc @@ -411,6 +411,20 @@ class BinaryReaderDelegateWalrus: public BinaryReaderDelegate { return Result::Ok; } + Result OnStartPreprocess() override { + m_externalDelegate->OnStartPreprocess(); + return Result::Ok; + } + + Result OnEndPreprocess() override { + m_externalDelegate->OnEndPreprocess(); + return Result::Ok; + } + + bool NeedsPreprocess() override { + return true; + } + /* Function expressions; called between BeginFunctionBody and EndFunctionBody */ Result OnOpcode(Opcode opcode) override {