From 16f08f8de8ca6db876626b25b01fb89a95438eae Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Sun, 27 Oct 2024 23:06:03 -0500
Subject: [PATCH] [IREE-EP] Integrate the IREE async module into the IREE-EP

Signed-off-by: Gaurav Shukla
---
 .../core/providers/iree/iree_ep_runtime.cc    | 215 +++++++++++-------
 .../core/providers/iree/iree_ep_runtime.h     |   3 +
 .../providers/iree/iree_execution_provider.cc |   2 +
 onnxruntime/test/perftest/ort_test_session.cc |   4 +
 4 files changed, 145 insertions(+), 79 deletions(-)

diff --git a/onnxruntime/core/providers/iree/iree_ep_runtime.cc b/onnxruntime/core/providers/iree/iree_ep_runtime.cc
index 086ef9962465a..90d69bb02a08f 100644
--- a/onnxruntime/core/providers/iree/iree_ep_runtime.cc
+++ b/onnxruntime/core/providers/iree/iree_ep_runtime.cc
@@ -4,6 +4,7 @@
 
 #include "core/providers/iree/iree_ep_runtime.h"
 #include "core/session/onnxruntime_cxx_api.h"
+#include <iostream>
 
 namespace onnxruntime::iree_ep_rt {
 
@@ -57,10 +58,18 @@ Session::~Session() {
 }
 
 iree_status_t Session::Initialize() {
-  return iree_runtime_session_create_with_device(
+  iree_status_t res_status = iree_runtime_session_create_with_device(
       instance->instance, &session_options, instance->device,
       iree_runtime_instance_host_allocator(instance->instance), &session);
+  iree_vm_module_t* custom_module = NULL;
+  iree_allocator_t host_allocator = iree_allocator_system();
+  IREE_CHECK_OK(iree_custom_module_async_create(
+      iree_runtime_instance_vm_instance(instance->instance), instance->device,
+      host_allocator, &custom_module));
+  IREE_CHECK_OK(iree_runtime_session_append_module(session, custom_module));
+  iree_vm_module_release(custom_module);
+  return res_status;
 }
 
 iree_status_t Session::AppendBytecodeModule(fs::path vmfb_path, std::function<void()> dispose_callback) {
@@ -147,6 +156,13 @@ iree_hal_element_type_t ConvertOrtElementType(ONNXTensorElementDataType et) {
 common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api, OrtKernelContext* ort_context_c) {
   // TODO: This is far from the most efficient way to make a call. Synchronous and copying. We can do
   // better but this gets points for simplicity and lets us bootstrap the tests.
+  iree_vm_list_t* inputs = NULL;
+  iree_allocator_t host_allocator = iree_allocator_system();
+  IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
+                                    host_allocator, &inputs));
+  iree_vm_list_t* outputs = NULL;
+  IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
+                                    host_allocator, &outputs));
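+  // NOTE: with the async module appended in Initialize(), the call arguments
+  // and the wait/signal fences travel in |inputs| and the results come back
+  // in |outputs|.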
   Ort::KernelContext context(ort_context_c);
   SynchronousCall call(session);
   ORT_RETURN_IF_ERROR(HandleIREEStatus(call.InitializeByName(entrypoint_name)));
@@ -161,59 +177,93 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
 
   // Process inputs. We could be smarter about this in a lot of ways, including carrying
   // more state from compilation so we are doing less munging here.
-  for (size_t i = 0; i < context.GetInputCount(); ++i) {
-    auto input_tensor = context.GetInput(i);
-    ORT_ENFORCE(input_tensor.IsTensor());
-
-    // The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
-    // is useful for anything.
-    auto ort_device_type = input_tensor.GetTensorMemoryInfo().GetDeviceType();
-    ORT_ENFORCE(ort_device_type == OrtMemoryInfoDeviceType_CPU);
-
-    const auto& tensor_type = input_tensor.GetTensorTypeAndShapeInfo();
-    auto element_type = ConvertOrtElementType(tensor_type.GetElementType());
-    ORT_ENFORCE(element_type != IREE_HAL_ELEMENT_TYPE_NONE, "Unsupported element type ",
-                static_cast<int>(tensor_type.GetElementType()));
-    ORT_ENFORCE(iree_hal_element_is_byte_aligned(element_type));
-    size_t element_size_bytes = iree_hal_element_dense_byte_count(element_type);
-
-    // Yes, that's right, returned as an std::vector by value :(
-    // And of a different type than we expect.
-    std::vector<int64_t> shape = tensor_type.GetShape();
-    dims.resize(shape.size());
-    std::copy(shape.begin(), shape.end(), dims.begin());
-
-    // No convenient way to get the byte size of the raw data.
-    size_t element_count = tensor_type.GetElementCount();
-    const void* raw_data = input_tensor.GetTensorRawData();
-
-    HalBufferView arg;
-    iree_hal_buffer_params_t buffer_params;
-    memset(&buffer_params, 0, sizeof(buffer_params));
-    buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
-    buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
-    buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
-    ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_view_allocate_buffer_copy(
-        device, device_allocator,
-        // Shape rank and dimensions:
-        dims.size(), dims.data(),
-        // Element type:
-        element_type,
-        // Encoding type:
-        IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
-        buffer_params,
-        // The actual heap buffer to wrap or clone and its allocator:
-        iree_make_const_byte_span(raw_data, element_count * element_size_bytes),
-        // Buffer view + storage are returned and owned by the caller:
-        &arg.bv)));
-
-    // Add it to the call.
-    iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
-    ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
-  }
+
+  std::cout << "input count: " << context.GetInputCount() << "\n";
+  // for (size_t i = 0; i < context.GetInputCount(); ++i) {
+  auto input_tensor = context.GetInput(0);
+  ORT_ENFORCE(input_tensor.IsTensor());
+
+  // The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
+  // is useful for anything.
+  auto ort_device_type = input_tensor.GetTensorMemoryInfo().GetDeviceType();
+  ORT_ENFORCE(ort_device_type == OrtMemoryInfoDeviceType_CPU);
+
+  const auto& tensor_type = input_tensor.GetTensorTypeAndShapeInfo();
+  auto element_type = ConvertOrtElementType(tensor_type.GetElementType());
+  ORT_ENFORCE(element_type != IREE_HAL_ELEMENT_TYPE_NONE, "Unsupported element type ",
+              static_cast<int>(tensor_type.GetElementType()));
+  ORT_ENFORCE(iree_hal_element_is_byte_aligned(element_type));
+  size_t element_size_bytes = iree_hal_element_dense_byte_count(element_type);
+
+  // Yes, that's right, returned as an std::vector by value :(
+  // And of a different type than we expect.
+  std::vector<int64_t> shape = tensor_type.GetShape();
+  dims.resize(shape.size());
+  std::copy(shape.begin(), shape.end(), dims.begin());
+
+  // No convenient way to get the byte size of the raw data.
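+  // Derive it as element_count * element_size_bytes for the dense row-major
+  // copy below.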
+  size_t element_count = tensor_type.GetElementCount();
+  const void* raw_data = input_tensor.GetTensorRawData();
+
+  HalBufferView arg;
+  iree_hal_buffer_params_t buffer_params;
+  memset(&buffer_params, 0, sizeof(buffer_params));
+  buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+  buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
+  buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
+  ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_view_allocate_buffer_copy(
+      device, device_allocator,
+      // Shape rank and dimensions:
+      dims.size(), dims.data(),
+      // Element type:
+      element_type,
+      // Encoding type:
+      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
+      buffer_params,
+      // The actual heap buffer to wrap or clone and its allocator:
+      iree_make_const_byte_span(raw_data, element_count * element_size_bytes),
+      // Buffer view + storage are returned and owned by the caller:
+      &arg.bv)));
+
+  iree_vm_ref_t input_view_ref = iree_hal_buffer_view_move_ref(arg.bv);
+  IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &input_view_ref));
+
+  iree_hal_semaphore_t* semaphore = NULL;
+  IREE_CHECK_OK(iree_hal_semaphore_create(
+      device, 0ull, IREE_HAL_SEMAPHORE_FLAG_NONE, &semaphore));
+  iree_hal_fence_t* fence_t1 = NULL;
+  IREE_CHECK_OK(
+      iree_hal_fence_create_at(semaphore, 1ull, host_allocator, &fence_t1));
+  iree_hal_fence_t* fence_t2 = NULL;
+  IREE_CHECK_OK(
+      iree_hal_fence_create_at(semaphore, 2ull, host_allocator, &fence_t2));
+  iree_hal_semaphore_release(semaphore);
+  std::cout << "\n semaphore released";
+  iree_vm_ref_t fence_t1_ref = iree_hal_fence_retain_ref(fence_t1);
+  std::cout << "\n semaphore released1";
+  IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t1_ref));
+  std::cout << "\n semaphore released2";
+  iree_vm_ref_t fence_t2_ref = iree_hal_fence_retain_ref(fence_t2);
+  std::cout << "\n semaphore released3";
+  IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t2_ref));
+  std::cout << "\n semaphore released4";
+  IREE_CHECK_OK(iree_hal_fence_signal(fence_t1));
+  std::cout << "\n T=1 reached";
+  // Add it to the call.
+  iree_string_view_t entry_point = iree_make_cstring_view(entrypoint_name);
+  IREE_CHECK_OK(
+      iree_runtime_session_call_by_name(session, entry_point, inputs, outputs));
+  // We could go do other things now while the async work progresses. Here we
+  // just immediately wait.
+  IREE_CHECK_OK(iree_hal_fence_wait(fence_t2, iree_infinite_timeout()));
+  std::cout << "\n T=2 reached";
+  // iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
+  // ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
+  // }
+
+  // Read back the tensor result:
 
   // Invoke.
-  ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, /*flags=*/0)));
+  // ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, /*flags=*/0)));
 
   // Marshal the outputs.
   // TODO: Accessing the ORT output requires the shape and then we could get zero copy
@@ -222,37 +272,44 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
   // convention, which allows passing in slabs of result buffers. Further, that would
   // run the host-side computation (which would compute output metadata) inline.
   // For static cases, we could also side-load the shape from the compile time.
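+  // fence_t2 was waited on above, so the buffer view in |outputs| is ready to
+  // read back here.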
-  std::vector<int64_t> shape;
-  for (size_t i = 0; i < context.GetOutputCount(); ++i) {
-    HalBufferView ret;
-    ORT_RETURN_IF_ERROR(HandleIREEStatus(
-        iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
-    size_t ret_rank = iree_hal_buffer_view_shape_rank(ret.bv);
-    const iree_hal_dim_t* ret_dims = iree_hal_buffer_view_shape_dims(ret.bv);
-    shape.resize(ret_rank);
-    std::copy(ret_dims, ret_dims + ret_rank, shape.begin());
-    auto output_tensor = context.GetOutput(i, shape.data(), shape.size());
-    ORT_ENFORCE(output_tensor.IsTensor());
-
-    iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
-    // TODO: Synchronous mapping read, like everything in this function, is not a
-    // great idea. It isn't supported on all device types and will need a scrub.
-    iree_string_view_t device_val = iree_hal_device_id(device);
-    auto device_str = std::string(device_val.data, device_val.size);
-    if (device_str == "hip") {
-      ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
-          iree_runtime_session_device(session),
-          ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
-          iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
-          iree_infinite_timeout())));
-      return common::Status::OK();
-    }
+  // std::vector<int64_t> shape;
+  std::cout << "output count: " << context.GetOutputCount() << "\n";
+  // for (size_t i = 0; i < context.GetOutputCount(); ++i) {
+  HalBufferView ret;
+  ret.bv = iree_vm_list_get_buffer_view_assign(outputs, 0);
+  // ORT_RETURN_IF_ERROR(HandleIREEStatus(
+  //     iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
+  size_t ret_rank = iree_hal_buffer_view_shape_rank(ret.bv);
+  const iree_hal_dim_t* ret_dims = iree_hal_buffer_view_shape_dims(ret.bv);
+  shape.clear();
+  shape.resize(ret_rank);
+  std::copy(ret_dims, ret_dims + ret_rank, shape.begin());
+  auto output_tensor = context.GetOutput(0, shape.data(), shape.size());
+  ORT_ENFORCE(output_tensor.IsTensor());
+
+  iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
+  // TODO: Synchronous mapping read, like everything in this function, is not a
+  // great idea. It isn't supported on all device types and will need a scrub.
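+  // Mapping is not supported for the HIP device, so that path uses an
+  // explicit device-to-host transfer below instead.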
+  iree_string_view_t device_val = iree_hal_device_id(device);
+  auto device_str = std::string(device_val.data, device_val.size);
+  if (device_str == "hip") {
+    ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
+        iree_runtime_session_device(session),
+        ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
+        iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
+        iree_infinite_timeout())));
+    return common::Status::OK();
+  }
 
   ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0,
                                                                 output_tensor.GetTensorMutableRawData(),
                                                                 iree_hal_buffer_view_byte_length(ret.bv))));
-  }
+  // }
 
-  return common::Status::OK();
+  iree_vm_list_release(inputs);
+  iree_vm_list_release(outputs);
+  iree_hal_fence_release(fence_t1);
+  iree_hal_fence_release(fence_t2);
+  return common::Status::OK();
 }
 
 } // namespace onnxruntime::iree_ep_rt

diff --git a/onnxruntime/core/providers/iree/iree_ep_runtime.h b/onnxruntime/core/providers/iree/iree_ep_runtime.h
index 3a1f5cb07d579..9e71f4d42e3db 100644
--- a/onnxruntime/core/providers/iree/iree_ep_runtime.h
+++ b/onnxruntime/core/providers/iree/iree_ep_runtime.h
@@ -5,8 +5,11 @@
 
 #include "core/common/common.h"
 #include "core/session/onnxruntime_c_api.h"
+#include "iree/modules/hal/types.h"
 #include "iree/runtime/api.h"
 
+#include "module.h"
+
 #include <filesystem>
 
 namespace fs = std::filesystem;

diff --git a/onnxruntime/core/providers/iree/iree_execution_provider.cc b/onnxruntime/core/providers/iree/iree_execution_provider.cc
index d504561707e60..bad069632941a 100644
--- a/onnxruntime/core/providers/iree/iree_execution_provider.cc
+++ b/onnxruntime/core/providers/iree/iree_execution_provider.cc
@@ -118,6 +118,8 @@ common::Status IREEExecutionProvider::Compile(const std::vector