From 168699a170d903a41527dfd9e52a138027ff1dca Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Mon, 4 Dec 2023 22:08:37 +0000 Subject: [PATCH] [XRT] Add device parameters to driver/device creation APIs This commit defines `iree_hal_xrt_device_params_t` for controlling major XRT device behavior. Right now we expose only the arena block size. --- runtime/src/iree-amd-aie/driver/xrt/api.h | 24 +++++++++- .../driver/xrt/registration/driver_module.c | 7 ++- .../src/iree-amd-aie/driver/xrt/xrt_device.cc | 44 ++++++++++++------- .../src/iree-amd-aie/driver/xrt/xrt_device.h | 19 +++++--- .../src/iree-amd-aie/driver/xrt/xrt_driver.cc | 44 ++++++++++++++----- 5 files changed, 103 insertions(+), 35 deletions(-) diff --git a/runtime/src/iree-amd-aie/driver/xrt/api.h b/runtime/src/iree-amd-aie/driver/xrt/api.h index 4f9d1d50c..33f48bb1a 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/api.h +++ b/runtime/src/iree-amd-aie/driver/xrt/api.h @@ -16,6 +16,25 @@ extern "C" { #endif // __cplusplus +//===----------------------------------------------------------------------===// +// iree_hal_xrt_device_params_t +//===----------------------------------------------------------------------===// + + +// Parameters configuring an iree_hal_xrt_device_t. +// Must be initialized with iree_hal_xrt_device_params_initialize prior to +// use. +typedef struct iree_hal_xrt_device_params_t { + // Total size of each block in the device shared block pool. + // Larger sizes will lower overhead and ensure the heap isn't hit for + // transient allocations while also increasing memory consumption. + iree_host_size_t arena_block_size; +} iree_hal_xrt_device_params_t; + +// Initializes |out_params| to default values. 
+void iree_hal_xrt_device_params_initialize( + iree_hal_xrt_device_params_t* out_params); + //===----------------------------------------------------------------------===// // iree_hal_xrt_driver_t //===----------------------------------------------------------------------===// @@ -25,8 +44,9 @@ extern "C" { // // |out_driver| must be released by the caller (see iree_hal_driver_release). IREE_API_EXPORT iree_status_t iree_hal_xrt_driver_create( - iree_string_view_t identifier, iree_allocator_t host_allocator, - iree_hal_driver_t** out_driver); + iree_string_view_t identifier, + const iree_hal_xrt_device_params_t* device_params, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver); #ifdef __cplusplus } // extern "C" diff --git a/runtime/src/iree-amd-aie/driver/xrt/registration/driver_module.c b/runtime/src/iree-amd-aie/driver/xrt/registration/driver_module.c index 4ed952dfd..724666309 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/registration/driver_module.c +++ b/runtime/src/iree-amd-aie/driver/xrt/registration/driver_module.c @@ -45,8 +45,11 @@ static iree_status_t iree_hal_xrt_driver_factory_try_create( IREE_TRACE_ZONE_BEGIN(z0); - iree_status_t status = - iree_hal_xrt_driver_create(driver_name, host_allocator, out_driver); + iree_hal_xrt_device_params_t device_params; + iree_hal_xrt_device_params_initialize(&device_params); + + iree_status_t status = iree_hal_xrt_driver_create(driver_name, &device_params, + host_allocator, out_driver); IREE_TRACE_ZONE_END(z0); diff --git a/runtime/src/iree-amd-aie/driver/xrt/xrt_device.cc b/runtime/src/iree-amd-aie/driver/xrt/xrt_device.cc index f5293dd91..5767784cd 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/xrt_device.cc +++ b/runtime/src/iree-amd-aie/driver/xrt/xrt_device.cc @@ -24,9 +24,7 @@ typedef struct iree_hal_xrt_device_t { iree_string_view_t identifier; - // Original driver that owns this device. 
- iree_hal_driver_t* driver; - + iree_hal_xrt_device_params_t params; iree_allocator_t host_allocator; iree_hal_allocator_t* device_allocator; @@ -43,9 +41,28 @@ static iree_hal_xrt_device_t* iree_hal_xrt_device_cast( return (iree_hal_xrt_device_t*)base_value; } +static const iree_hal_xrt_device_t* iree_hal_xrt_device_const_cast( + const iree_hal_device_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_xrt_device_vtable); + return (const iree_hal_xrt_device_t*)base_value; +} + +void iree_hal_xrt_device_params_initialize( + iree_hal_xrt_device_params_t* out_params) { + memset(out_params, 0, sizeof(*out_params)); + out_params->arena_block_size = 32 * 1024; +} + +const iree_hal_xrt_device_params_t* iree_hal_xrt_device_params( + const iree_hal_device_t* base_device) { + const iree_hal_xrt_device_t* device = + iree_hal_xrt_device_const_cast(base_device); + return &device->params; +} + static iree_status_t iree_hal_xrt_device_create_internal( - iree_hal_driver_t* driver, iree_string_view_t identifier, - xrt::device xrt_device, iree_allocator_t host_allocator, + iree_string_view_t identifier, xrt::device xrt_device, + const iree_hal_xrt_device_params_t* params, iree_allocator_t host_allocator, iree_hal_device_t** out_device) { iree_hal_xrt_device_t* device = NULL; @@ -62,10 +79,10 @@ static iree_status_t iree_hal_xrt_device_create_internal( iree_string_view_append_to_buffer( identifier, &device->identifier, (char*)device + iree_sizeof_struct(*device)); - device->driver = driver; - iree_hal_driver_retain(device->driver); + device->host_allocator = host_allocator; device->device = xrt_device; + device->params = *params; *out_device = (iree_hal_device_t*)device; } else { @@ -74,16 +91,15 @@ static iree_status_t iree_hal_xrt_device_create_internal( return status; } -iree_status_t iree_hal_xrt_device_create(iree_hal_driver_t* driver, - iree_string_view_t identifier, - xrt::device device, - iree_allocator_t host_allocator, - iree_hal_device_t** out_device) { 
+iree_status_t iree_hal_xrt_device_create( + iree_string_view_t identifier, const iree_hal_xrt_device_params_t* params, + xrt::device device, iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { IREE_ASSERT_ARGUMENT(out_device); IREE_TRACE_ZONE_BEGIN(z0); iree_status_t status = iree_hal_xrt_device_create_internal( - driver, identifier, device, host_allocator, out_device); + identifier, device, params, host_allocator, out_device); IREE_TRACE_ZONE_END(z0); return status; @@ -95,8 +111,6 @@ static void iree_hal_xrt_device_destroy(iree_hal_device_t* base_device) { IREE_TRACE_ZONE_BEGIN(z0); iree_hal_allocator_release(device->device_allocator); - iree_hal_driver_release(device->driver); - iree_allocator_free(host_allocator, device); IREE_TRACE_ZONE_END(z0); diff --git a/runtime/src/iree-amd-aie/driver/xrt/xrt_device.h b/runtime/src/iree-amd-aie/driver/xrt/xrt_device.h index bccaf08b8..50540a14f 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/xrt_device.h +++ b/runtime/src/iree-amd-aie/driver/xrt/xrt_device.h @@ -7,6 +7,7 @@ #ifndef IREE_AMD_AIE_DRIVER_XRT_XRT_DEVICE_H_ #define IREE_AMD_AIE_DRIVER_XRT_XRT_DEVICE_H_ +#include "iree-amd-aie/driver/xrt/api.h" #include "iree/base/api.h" #include "iree/hal/api.h" #include "xrt/xrt_device.h" @@ -15,12 +16,18 @@ extern "C" { #endif // __cplusplus -// Creates a XRT device. -iree_status_t iree_hal_xrt_device_create(iree_hal_driver_t* driver, - iree_string_view_t identifier, - xrt::device device, - iree_allocator_t host_allocator, - iree_hal_device_t** out_device); +// Creates an XRT device by wrapping |device|, configured with the given +// |params|. +// +// |out_device| must be released by the caller (see iree_hal_device_release). +iree_status_t iree_hal_xrt_device_create( + iree_string_view_t identifier, const iree_hal_xrt_device_params_t* params, + xrt::device device, iree_allocator_t host_allocator, + iree_hal_device_t** out_device); + +// Returns the parameters used for creating the device. 
+const iree_hal_xrt_device_params_t* iree_hal_xrt_device_params( + const iree_hal_device_t* device); #ifdef __cplusplus } // extern "C" diff --git a/runtime/src/iree-amd-aie/driver/xrt/xrt_driver.cc b/runtime/src/iree-amd-aie/driver/xrt/xrt_driver.cc index c15fdd776..2ff9bd138 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/xrt_driver.cc +++ b/runtime/src/iree-amd-aie/driver/xrt/xrt_driver.cc @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree-amd-aie/driver/xrt/api.h" #include "iree-amd-aie/driver/xrt/xrt_device.h" #include "iree/base/api.h" #include "iree/base/target_platform.h" @@ -39,6 +38,9 @@ typedef struct iree_hal_xrt_driver_t { // Identifier used for the driver in the IREE driver registry.. iree_string_view_t identifier; + // Parameters used to control device behavior. + iree_hal_xrt_device_params_t device_params; + xrt::device device; } iree_hal_xrt_driver_t; @@ -53,9 +55,25 @@ static iree_hal_xrt_driver_t* iree_hal_xrt_driver_cast( return (iree_hal_xrt_driver_t*)base_value; } +static const iree_hal_xrt_driver_t* iree_hal_xrt_driver_const_cast( + const iree_hal_driver_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_xrt_driver_vtable); + return (const iree_hal_xrt_driver_t*)base_value; +} + +static iree_status_t iree_hal_xrt_device_check_params( + const iree_hal_xrt_device_params_t* params) { + if (params->arena_block_size < 4096) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "arena block size too small (< 4096 bytes)"); + } + return iree_ok_status(); +} + iree_status_t iree_hal_xrt_driver_create_internal( - iree_string_view_t identifier, iree_allocator_t host_allocator, - iree_hal_driver_t** out_driver) { + iree_string_view_t identifier, + const iree_hal_xrt_device_params_t* device_params, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { iree_hal_xrt_driver_t* driver = NULL; iree_host_size_t total_size = 
iree_sizeof_struct(*driver) + identifier.size; IREE_RETURN_IF_ERROR( @@ -66,6 +84,7 @@ iree_status_t iree_hal_xrt_driver_create_internal( iree_string_view_append_to_buffer( identifier, &driver->identifier, (char*)driver + iree_sizeof_struct(*driver)); + driver->device_params = *device_params; int device_count = xrt::system::enumerate_devices(); if (IREE_UNLIKELY(device_count == 0)) { @@ -79,13 +98,16 @@ iree_status_t iree_hal_xrt_driver_create_internal( } IREE_API_EXPORT iree_status_t iree_hal_xrt_driver_create( - iree_string_view_t identifier, iree_allocator_t host_allocator, - iree_hal_driver_t** out_driver) { + iree_string_view_t identifier, + const iree_hal_xrt_device_params_t* device_params, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { IREE_ASSERT_ARGUMENT(out_driver); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_xrt_device_check_params(device_params)); iree_status_t status = iree_hal_xrt_driver_create_internal( - identifier, host_allocator, out_driver); + identifier, device_params, host_allocator, out_driver); IREE_TRACE_ZONE_END(z0); return status; @@ -192,8 +214,9 @@ static iree_status_t iree_hal_xrt_driver_create_device_by_id( iree_hal_xrt_driver_t* driver = iree_hal_xrt_driver_cast(base_driver); iree_string_view_t device_name = iree_make_cstring_view("xrt"); - iree_status_t status = iree_hal_xrt_device_create( - base_driver, device_name, driver->device, host_allocator, out_device); + iree_status_t status = + iree_hal_xrt_device_create(device_name, &driver->device_params, + driver->device, host_allocator, out_device); IREE_TRACE_ZONE_END(z0); return status; @@ -208,8 +231,9 @@ static iree_status_t iree_hal_xrt_driver_create_device_by_path( iree_hal_xrt_driver_t* driver = iree_hal_xrt_driver_cast(base_driver); iree_string_view_t device_name = iree_make_cstring_view("xrt"); - iree_status_t status = iree_hal_xrt_device_create( - base_driver, device_name, driver->device, host_allocator, out_device); + 
iree_status_t status = + iree_hal_xrt_device_create(device_name, &driver->device_params, + driver->device, host_allocator, out_device); IREE_TRACE_ZONE_END(z0); return status;