Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IREE-EP] Integrate iree async module in the IREE-EP #15

Draft
wants to merge 1 commit into
base: onnxrt-rebase
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 136 additions & 79 deletions onnxruntime/core/providers/iree/iree_ep_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "core/providers/iree/iree_ep_runtime.h"

#include "core/session/onnxruntime_cxx_api.h"
#include <iostream>

Check warning on line 7 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: iree_ep_runtime.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:7: Found C++ system header after other header. Should be: iree_ep_runtime.h, c system, c++ system, other. [build/include_order] [4]

namespace onnxruntime::iree_ep_rt {

Expand Down Expand Up @@ -57,10 +58,18 @@
}

// Creates the IREE runtime session on the instance's device and registers the
// custom async module with it.
//
// Returns the first failing status from session creation, custom module
// creation, or module registration; returns OK only when all three succeed.
// Failures are propagated to the caller instead of aborting the process
// (IREE_CHECK_OK would abort), and the session is never used if its creation
// failed.
iree_status_t Session::Initialize() {
  // Bail out before touching `session` if it could not be created.
  IREE_RETURN_IF_ERROR(iree_runtime_session_create_with_device(
      instance->instance, &session_options, instance->device,
      iree_runtime_instance_host_allocator(instance->instance),
      &session));

  iree_vm_module_t* custom_module = NULL;
  iree_allocator_t host_allocator = iree_allocator_system();
  IREE_RETURN_IF_ERROR(iree_custom_module_async_create(
      iree_runtime_instance_vm_instance(instance->instance), instance->device,
      host_allocator, &custom_module));

  // Drop our local reference whether or not the append succeeded so the
  // module is not leaked on the error path.
  iree_status_t append_status =
      iree_runtime_session_append_module(session, custom_module);
  iree_vm_module_release(custom_module);
  return append_status;
}

iree_status_t Session::AppendBytecodeModule(fs::path vmfb_path, std::function<void()> dispose_callback) {
Expand Down Expand Up @@ -147,6 +156,13 @@
common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api, OrtKernelContext* ort_context_c) {
// TODO: This is far from the most efficient way to make a call. Synchronous and copying. We can do
// better but this gets points for simplicity and lets us bootstrap the tests.
iree_vm_list_t* inputs = NULL;
iree_allocator_t host_allocator = iree_allocator_system();
IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
host_allocator, &inputs));
iree_vm_list_t* outputs = NULL;
IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
host_allocator, &outputs));
Ort::KernelContext context(ort_context_c);
SynchronousCall call(session);
ORT_RETURN_IF_ERROR(HandleIREEStatus(call.InitializeByName(entrypoint_name)));
Expand All @@ -161,59 +177,93 @@

// Process inputs. We could be smarter about this in a lot of ways, including carrying
// more state from compilation so we are doing less munging here.
for (size_t i = 0; i < context.GetInputCount(); ++i) {
auto input_tensor = context.GetInput(i);
ORT_ENFORCE(input_tensor.IsTensor());

// The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
// is useful for anything.
auto ort_device_type = input_tensor.GetTensorMemoryInfo().GetDeviceType();
ORT_ENFORCE(ort_device_type == OrtMemoryInfoDeviceType_CPU);

const auto& tensor_type = input_tensor.GetTensorTypeAndShapeInfo();
auto element_type = ConvertOrtElementType(tensor_type.GetElementType());
ORT_ENFORCE(element_type != IREE_HAL_ELEMENT_TYPE_NONE, "Unsupported element type ",
static_cast<int>(tensor_type.GetElementType()));
ORT_ENFORCE(iree_hal_element_is_byte_aligned(element_type));
size_t element_size_bytes = iree_hal_element_dense_byte_count(element_type);

// Yes, that's right, returned as an std::vector by value :(
// And of a different type than we expect.
std::vector<int64_t> shape = tensor_type.GetShape();
dims.resize(shape.size());
std::copy(shape.begin(), shape.end(), dims.begin());

// No convenient way to get the byte size of the raw data.
size_t element_count = tensor_type.GetElementCount();
const void* raw_data = input_tensor.GetTensorRawData();

HalBufferView arg;
iree_hal_buffer_params_t buffer_params;
memset(&buffer_params, 0, sizeof(buffer_params));
buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_view_allocate_buffer_copy(
device, device_allocator,
// Shape rank and dimensions:
dims.size(), dims.data(),
// Element type:
element_type,
// Encoding type:
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
buffer_params,
// The actual heap buffer to wrap or clone and its allocator:
iree_make_const_byte_span(raw_data, element_count * element_size_bytes),
// Buffer view + storage are returned and owned by the caller:
&arg.bv)));

// Add it to the call.
iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
}

std::cout << "input count: " << context.GetInputCount() << "\n";
// for (size_t i = 0; i < context.GetInputCount(); ++i) {
auto input_tensor = context.GetInput(0);
ORT_ENFORCE(input_tensor.IsTensor());

// The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
// is useful for anything.
auto ort_device_type = input_tensor.GetTensorMemoryInfo().GetDeviceType();
ORT_ENFORCE(ort_device_type == OrtMemoryInfoDeviceType_CPU);

const auto& tensor_type = input_tensor.GetTensorTypeAndShapeInfo();
auto element_type = ConvertOrtElementType(tensor_type.GetElementType());
ORT_ENFORCE(element_type != IREE_HAL_ELEMENT_TYPE_NONE, "Unsupported element type ",
static_cast<int>(tensor_type.GetElementType()));
ORT_ENFORCE(iree_hal_element_is_byte_aligned(element_type));
size_t element_size_bytes = iree_hal_element_dense_byte_count(element_type);

// Yes, that's right, returned as an std::vector by value :(
// And of a different type than we expect.
std::vector<int64_t> shape = tensor_type.GetShape();

Check warning on line 200 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:200: Add #include <vector> for vector<> [build/include_what_you_use] [4]
dims.resize(shape.size());
std::copy(shape.begin(), shape.end(), dims.begin());

// No convenient way to get the byte size of the raw data.
size_t element_count = tensor_type.GetElementCount();
const void* raw_data = input_tensor.GetTensorRawData();

HalBufferView arg;
iree_hal_buffer_params_t buffer_params;
memset(&buffer_params, 0, sizeof(buffer_params));
buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_view_allocate_buffer_copy(
device, device_allocator,
// Shape rank and dimensions:
dims.size(), dims.data(),
// Element type:
element_type,
// Encoding type:
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
buffer_params,
// The actual heap buffer to wrap or clone and its allocator:
iree_make_const_byte_span(raw_data, element_count * element_size_bytes),
// Buffer view + storage are returned and owned by the caller:
&arg.bv)));

iree_vm_ref_t input_view_ref = iree_hal_buffer_view_move_ref(arg.bv);
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &input_view_ref));

iree_hal_semaphore_t* semaphore = NULL;
IREE_CHECK_OK(iree_hal_semaphore_create(
device, 0ull, IREE_HAL_SEMAPHORE_FLAG_NONE, &semaphore));
iree_hal_fence_t* fence_t1 = NULL;
IREE_CHECK_OK(
iree_hal_fence_create_at(semaphore, 1ull, host_allocator, &fence_t1));
iree_hal_fence_t* fence_t2 = NULL;
IREE_CHECK_OK(
iree_hal_fence_create_at(semaphore, 2ull, host_allocator, &fence_t2));
iree_hal_semaphore_release(semaphore);
std::cout << "\n semaphore released";
iree_vm_ref_t fence_t1_ref = iree_hal_fence_retain_ref(fence_t1);
std::cout << "\n semaphore released1";
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t1_ref));
std::cout << "\n semaphore released2";
iree_vm_ref_t fence_t2_ref = iree_hal_fence_retain_ref(fence_t2);
std::cout << "\n semaphore released3";
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t2_ref));
std::cout << "\n semaphore released4";
IREE_CHECK_OK(iree_hal_fence_signal(fence_t1));
std::cout << "\n T=1 reached";
// Add it to the call.
iree_string_view_t entry_point = iree_make_cstring_view(entrypoint_name);
IREE_CHECK_OK(
iree_runtime_session_call_by_name(session, entry_point, inputs, outputs));
// We could go do other things now while the async work progresses. Here we
// just immediately wait.
IREE_CHECK_OK(iree_hal_fence_wait(fence_t2, iree_infinite_timeout()));
std::cout << "\n T=2 reached";
// iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
// ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
// }
// Read back the tensor<?xi32> result:

// Invoke.
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, /*flags=*/0)));
// ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, [>flags=<]0)));

// Marshal the outputs.
// TODO: Accessing the ORT output requires the shape and then we could get zero copy
Expand All @@ -222,37 +272,44 @@
// convention, which allows passing in slabs of result buffers. Further, that would
// run the host-side computation (which would compute output metadata) inline.
// For static cases, we could also side-load the shape from the compile time.
std::vector<int64_t> shape;
for (size_t i = 0; i < context.GetOutputCount(); ++i) {
HalBufferView ret;
ORT_RETURN_IF_ERROR(HandleIREEStatus(
iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
size_t ret_rank = iree_hal_buffer_view_shape_rank(ret.bv);
const iree_hal_dim_t* ret_dims = iree_hal_buffer_view_shape_dims(ret.bv);
shape.resize(ret_rank);
std::copy(ret_dims, ret_dims + ret_rank, shape.begin());
auto output_tensor = context.GetOutput(i, shape.data(), shape.size());
ORT_ENFORCE(output_tensor.IsTensor());

iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
// TODO: Synchronous mapping read, like everything in this function, is not a
// great idea. It isn't supported on all device types and will need a scrub.
iree_string_view_t device_val = iree_hal_device_id(device);
auto device_str = std::string(device_val.data, device_val.size);
if (device_str == "hip") {
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
iree_runtime_session_device(session),
ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
iree_infinite_timeout())));
return common::Status::OK();
}
// std::vector<int64_t> shape;
std::cout << "output count: " << context.GetOutputCount() << "\n";
// for (size_t i = 0; i < context.GetOutputCount(); ++i) {
HalBufferView ret;
ret.bv = iree_vm_list_get_buffer_view_assign(outputs, 0);
// ORT_RETURN_IF_ERROR(HandleIREEStatus(
// iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
size_t ret_rank = iree_hal_buffer_view_shape_rank(ret.bv);
const iree_hal_dim_t* ret_dims = iree_hal_buffer_view_shape_dims(ret.bv);
shape.clear();
shape.resize(ret_rank);
std::copy(ret_dims, ret_dims + ret_rank, shape.begin());

Check warning on line 286 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for copy [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:286: Add #include <algorithm> for copy [build/include_what_you_use] [4]
auto output_tensor = context.GetOutput(0, shape.data(), shape.size());
ORT_ENFORCE(output_tensor.IsTensor());

iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
// TODO: Synchronous mapping read, like everything in this function, is not a

Check warning on line 291 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:291: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// great idea. It isn't supported on all device types and will need a scrub.
iree_string_view_t device_val = iree_hal_device_id(device);
auto device_str = std::string(device_val.data, device_val.size);

Check warning on line 294 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:294: Add #include <string> for string [build/include_what_you_use] [4]
if (device_str == "hip") {
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
iree_runtime_session_device(session),
ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
iree_infinite_timeout())));
return common::Status::OK();
}
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0,
output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv))));
}
// }

return common::Status::OK();
iree_vm_list_release(inputs);
iree_vm_list_release(outputs);
iree_hal_fence_release(fence_t1);
iree_hal_fence_release(fence_t2);
return common::Status::OK();
}

} // namespace onnxruntime::iree_ep_rt
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/iree/iree_ep_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

#include "core/common/common.h"
#include "core/session/onnxruntime_c_api.h"
#include "iree/modules/hal/types.h"
#include "iree/runtime/api.h"

#include "module.h"

Check warning on line 11 in onnxruntime/core/providers/iree/iree_ep_runtime.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.h:11: Include the directory when naming header files [build/include_subdir] [4]

#include <filesystem>

namespace fs = std::filesystem;
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/iree/iree_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << extra_flag;
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag.c_str()));
}
std::string extra_flag_2 = "--iree-execution-model=async-external";
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag_2.c_str()));

ORT_RETURN_IF_ERROR(compiler.Initialize());
std::string module_name = "ort";
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/test/perftest/ort_test_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
#include "core/providers/dml/dml_session_options_config_keys.h"
#endif

#ifdef USE_IREE
#include "core/providers/iree/iree_provider_factory.h"
#endif

#ifdef _WIN32
#define strdup _strdup
#endif
Expand Down
Loading