From 08acaceefd967dc070d551e4c2a8ba6566cc308e Mon Sep 17 00:00:00 2001 From: Vinayak Dev <104419489+vinayakdsci@users.noreply.github.com> Date: Tue, 27 Aug 2024 17:49:12 +0530 Subject: [PATCH] [IREE-EP] Register IREE EP in OnnxRT's python bindings. (#4) * [IREE-EP] Register IREE EP in OnnxRT's python bindings. * Add TODO comments. --- .../core/providers/iree/iree_execution_provider.h | 1 + onnxruntime/core/session/provider_registration.cc | 2 +- onnxruntime/python/onnxruntime_pybind_schema.cc | 3 +++ onnxruntime/python/onnxruntime_pybind_state.cc | 10 ++++++++++ onnxruntime/test/perftest/command_args_parser.cc | 2 ++ onnxruntime/test/perftest/ort_test_session.cc | 3 ++- 6 files changed, 19 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/iree/iree_execution_provider.h b/onnxruntime/core/providers/iree/iree_execution_provider.h index 30945685d8dc6..d81d5909678d3 100644 --- a/onnxruntime/core/providers/iree/iree_execution_provider.h +++ b/onnxruntime/core/providers/iree/iree_execution_provider.h @@ -11,6 +11,7 @@ #include "core/framework/execution_provider.h" +#include "core/framework/provider_options.h" #include "core/providers/iree/iree_ep_runtime.h" namespace onnxruntime { diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 4cb7c56856867..860b8afca8a97 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -157,7 +157,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, #endif } else if (strcmp(provider_name, "IREE") == 0) { #if defined(USE_IREE) - options->provider_factories.push_back(IREEProviderFactoryCreator::Create(provider_options_keys)); + options->provider_factories.push_back(IREEProviderFactoryCreator::Create(provider_options)); #else status = create_not_supported_status(); #endif diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index c5757095e2e1e..9082cc9067078 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -53,6 +53,9 @@ void addGlobalSchemaFunctions(pybind11::module& m) { #ifdef USE_VITISAI onnxruntime::VitisAIProviderFactoryCreator::Create(ProviderOptions{}), #endif +#ifdef USE_IREE + onnxruntime::IREEProviderFactoryCreator::Create(ProviderOptions{}), +#endif #ifdef USE_ACL onnxruntime::ACLProviderFactoryCreator::Create(0), #endif diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 47b8d75f22aea..9a74b527ef6ee 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1138,6 +1138,16 @@ std::unique_ptr CreateExecutionProviderInstance( info["ep_context_embed_mode"] = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1"); info["ep_context_file_path"] = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); return onnxruntime::VitisAIProviderFactoryCreator::Create(info)->CreateProvider(); +#endif + } else if (type == kIreeExecutionProvider) { +#if USE_IREE + const auto &it = provider_options_map.find(type); + ProviderOptions iree_option_map = ProviderOptions{}; + if (it != provider_options_map.end()) { + iree_option_map = it->second; + } + return onnxruntime::IREEProviderFactoryCreator::Create(iree_option_map) + ->CreateProvider(); #endif } else if (type == kAclExecutionProvider) { #ifdef USE_ACL diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 7d06bbadbd645..43f1e770b5c72 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -280,6 +280,8 @@ static bool ParseSessionConfigs(const std::string& configs_string, test_config.machine_config.provider_type_name = onnxruntime::kXnnpackExecutionProvider; } else if (!CompareCString(optarg, ORT_TSTR("vitisai"))) { test_config.machine_config.provider_type_name = onnxruntime::kVitisAIExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("iree"))) { + test_config.machine_config.provider_type_name = onnxruntime::kIreeExecutionProvider; } else { return false; } diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 837aeb3c37acd..38496893d2c05 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -582,7 +582,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } else if (!provider_name_.empty() && provider_name_ != onnxruntime::kCpuExecutionProvider && - provider_name_ != onnxruntime::kOpenVINOExecutionProvider) { + provider_name_ != onnxruntime::kOpenVINOExecutionProvider && + provider_name_ != onnxruntime::kIreeExecutionProvider) { ORT_THROW("This backend is not included in perf test runner.\n"); }