From 13f625ab40a7319be65ef937306e6e67cc3402fc Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Tue, 16 May 2023 13:55:30 -0700 Subject: [PATCH] more fixes Signed-off-by: Soren Lassen --- src/Builder/FrontendDialectTransformer.cpp | 33 +++++++++++++--------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 8abfc3b6250..88db43c866d 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -153,6 +153,13 @@ class FrontendGenImpl { } } + int64_t GetDomainVersion(const std::string &domain) { + auto it = opset_map_.find(domain); + if (it == opset_map_.end()) + return 0; + return it->second; + } + void BindOnnxName(const std::string &onnx_name, Value symbol) { frontend_symbols_.AddMapping(onnx_name, symbol); } @@ -694,20 +701,21 @@ class FrontendGenImpl { } const onnx::OpSchema *GetOpSchema(const onnx::NodeProto &node) { - auto &domain = node.domain(); - auto version_it = opset_map_.find(domain); - if (version_it == opset_map_.end()) + int64_t version = GetDomainVersion(node.domain()); + if (version == 0) return nullptr; - auto version = version_it->second; - return onnx::OpSchemaRegistry::Schema(node.op_type(), version, domain); + return onnx::OpSchemaRegistry::Schema( + node.op_type(), version, node.domain()); } std::string GetImportVersionOfNode(const onnx::NodeProto &node) { - auto current_opset = opset_map_.find(node.domain())->second; + int64_t version = GetDomainVersion(node.domain()); + if (version == 0) + return ""; LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing ONNX" << node.op_type() << " (" << node.name() << ")" - << ", Opset: " << current_opset << "\n"); + << ", Opset: " << version << "\n"); auto opset_list_it = op_dialect_version_map_.find(node.op_type()); @@ -723,23 +731,20 @@ class FrontendGenImpl { // It is the current opset when onnx-mlir project is started. // All opset lower than the last opset should use the last opset(version) if (node.domain().compare("ai.onnx.ml") != 0 && - current_opset < opset_list.back() && - current_opset < MINIMUM_SUPPORTED_OPSET) + version < opset_list.back() && version < MINIMUM_SUPPORTED_OPSET) llvm::outs() << "Warning: ONNX " << node.op_type() - << " in your model is using Opset " << current_opset + << " in your model is using Opset " << version << ", which is quite old. Please consider regenerating your " "model with a newer Opset.\n"; for (int i = opset_list.size() - 1; i > 0; i--) { - LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - testing Opset " - << opset_list[i - 1] << "\n"); - if (current_opset < opset_list[i - 1]) { + if (version < opset_list[i - 1]) { LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - use Opset " << opset_list[i] << "\n"); return "V" + std::to_string(opset_list[i]); } } - return std::string(""); + return ""; } func::FuncOp CreateFuncOp(