Skip to content

Commit

Permalink
[XRT] Add device parameters to driver/device creation APIs
Browse files Browse the repository at this point in the history
This commit defines `iree_hal_metal_device_params_t` for
controlling major XRT device behavior. Right now we expose
arena block size and command dispatch type.
  • Loading branch information
nirvedhmeshram committed Dec 7, 2023
1 parent 2b77d09 commit 168699a
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 35 deletions.
24 changes: 22 additions & 2 deletions runtime/src/iree-amd-aie/driver/xrt/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,25 @@
extern "C" {
#endif // __cplusplus

//===----------------------------------------------------------------------===//
// iree_hal_xrt_device_params_t
//===----------------------------------------------------------------------===//


// Parameters configuring an iree_hal_xrt_device_t.
// Must be initialized with iree_hal_xrt_device_params_initialize prior to
// use.
typedef struct iree_hal_xrt_device_params_t {
// Total size of each block in the device shared block pool.
// Larger sizes will lower overhead and ensure the heap isn't hit for
// transient allocations while also increasing memory consumption.
iree_host_size_t arena_block_size;
} iree_hal_xrt_device_params_t;

// Initializes |out_params| to default values.
void iree_hal_xrt_device_params_initialize(
iree_hal_xrt_device_params_t* out_params);

//===----------------------------------------------------------------------===//
// iree_hal_xrt_driver_t
//===----------------------------------------------------------------------===//
Expand All @@ -25,8 +44,9 @@ extern "C" {
//
// |out_driver| must be released by the caller (see iree_hal_driver_release).
IREE_API_EXPORT iree_status_t iree_hal_xrt_driver_create(
iree_string_view_t identifier, iree_allocator_t host_allocator,
iree_hal_driver_t** out_driver);
iree_string_view_t identifier,
const iree_hal_xrt_device_params_t* device_params,
iree_allocator_t host_allocator, iree_hal_driver_t** out_driver);

#ifdef __cplusplus
} // extern "C"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ static iree_status_t iree_hal_xrt_driver_factory_try_create(

IREE_TRACE_ZONE_BEGIN(z0);

iree_status_t status =
iree_hal_xrt_driver_create(driver_name, host_allocator, out_driver);
iree_hal_xrt_device_params_t device_params;
iree_hal_xrt_device_params_initialize(&device_params);

iree_status_t status = iree_hal_xrt_driver_create(driver_name, &device_params,
host_allocator, out_driver);

IREE_TRACE_ZONE_END(z0);

Expand Down
44 changes: 29 additions & 15 deletions runtime/src/iree-amd-aie/driver/xrt/xrt_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ typedef struct iree_hal_xrt_device_t {

iree_string_view_t identifier;

// Original driver that owns this device.
iree_hal_driver_t* driver;

iree_hal_xrt_device_params_t params;
iree_allocator_t host_allocator;
iree_hal_allocator_t* device_allocator;

Expand All @@ -43,9 +41,28 @@ static iree_hal_xrt_device_t* iree_hal_xrt_device_cast(
return (iree_hal_xrt_device_t*)base_value;
}

static const iree_hal_xrt_device_t* iree_hal_xrt_device_const_cast(
const iree_hal_device_t* base_value) {
IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_xrt_device_vtable);
return (const iree_hal_xrt_device_t*)base_value;
}

void iree_hal_xrt_device_params_initialize(
iree_hal_xrt_device_params_t* out_params) {
memset(out_params, 0, sizeof(*out_params));
out_params->arena_block_size = 32 * 1024;
}

const iree_hal_xrt_device_params_t* iree_hal_xrt_device_params(
const iree_hal_device_t* base_device) {
const iree_hal_xrt_device_t* device =
iree_hal_xrt_device_const_cast(base_device);
return &device->params;
}

static iree_status_t iree_hal_xrt_device_create_internal(
iree_hal_driver_t* driver, iree_string_view_t identifier,
xrt::device xrt_device, iree_allocator_t host_allocator,
iree_string_view_t identifier, xrt::device xrt_device,
const iree_hal_xrt_device_params_t* params, iree_allocator_t host_allocator,
iree_hal_device_t** out_device) {
iree_hal_xrt_device_t* device = NULL;

Expand All @@ -62,10 +79,10 @@ static iree_status_t iree_hal_xrt_device_create_internal(
iree_string_view_append_to_buffer(
identifier, &device->identifier,
(char*)device + iree_sizeof_struct(*device));
device->driver = driver;
iree_hal_driver_retain(device->driver);

device->host_allocator = host_allocator;
device->device = xrt_device;
device->params = *params;

*out_device = (iree_hal_device_t*)device;
} else {
Expand All @@ -74,16 +91,15 @@ static iree_status_t iree_hal_xrt_device_create_internal(
return status;
}

iree_status_t iree_hal_xrt_device_create(iree_hal_driver_t* driver,
iree_string_view_t identifier,
xrt::device device,
iree_allocator_t host_allocator,
iree_hal_device_t** out_device) {
iree_status_t iree_hal_xrt_device_create(
iree_string_view_t identifier, const iree_hal_xrt_device_params_t* params,
xrt::device device, iree_allocator_t host_allocator,
iree_hal_device_t** out_device) {
IREE_ASSERT_ARGUMENT(out_device);
IREE_TRACE_ZONE_BEGIN(z0);

iree_status_t status = iree_hal_xrt_device_create_internal(
driver, identifier, device, host_allocator, out_device);
identifier, device, params, host_allocator, out_device);

IREE_TRACE_ZONE_END(z0);
return status;
Expand All @@ -95,8 +111,6 @@ static void iree_hal_xrt_device_destroy(iree_hal_device_t* base_device) {
IREE_TRACE_ZONE_BEGIN(z0);

iree_hal_allocator_release(device->device_allocator);
iree_hal_driver_release(device->driver);

iree_allocator_free(host_allocator, device);

IREE_TRACE_ZONE_END(z0);
Expand Down
19 changes: 13 additions & 6 deletions runtime/src/iree-amd-aie/driver/xrt/xrt_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef IREE_AMD_AIE_DRIVER_XRT_XRT_DEVICE_H_
#define IREE_AMD_AIE_DRIVER_XRT_XRT_DEVICE_H_

#include "iree-amd-aie/driver/xrt/api.h"
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "xrt/xrt_device.h"
Expand All @@ -15,12 +16,18 @@
extern "C" {
#endif // __cplusplus

// Creates a XRT device.
iree_status_t iree_hal_xrt_device_create(iree_hal_driver_t* driver,
iree_string_view_t identifier,
xrt::device device,
iree_allocator_t host_allocator,
iree_hal_device_t** out_device);
// Creates a XRT device by wrapping |device| from the given |driver| with the
// specific |params|.
//
// |out_device| must be released by the caller (see iree_hal_device_release).
iree_status_t iree_hal_xrt_device_create(
iree_string_view_t identifier, const iree_hal_xrt_device_params_t* params,
xrt::device device, iree_allocator_t host_allocator,
iree_hal_device_t** out_device);

// Returns the parameters used for creating the device.
const iree_hal_xrt_device_params_t* iree_hal_xrt_device_params(
const iree_hal_device_t* device);

#ifdef __cplusplus
} // extern "C"
Expand Down
44 changes: 34 additions & 10 deletions runtime/src/iree-amd-aie/driver/xrt/xrt_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/driver/xrt/api.h"
#include "iree-amd-aie/driver/xrt/xrt_device.h"
#include "iree/base/api.h"
#include "iree/base/target_platform.h"
Expand Down Expand Up @@ -39,6 +38,9 @@ typedef struct iree_hal_xrt_driver_t {
// Identifier used for the driver in the IREE driver registry..
iree_string_view_t identifier;

// Parameters used to control device behavior.
iree_hal_xrt_device_params_t device_params;

xrt::device device;

} iree_hal_xrt_driver_t;
Expand All @@ -53,9 +55,25 @@ static iree_hal_xrt_driver_t* iree_hal_xrt_driver_cast(
return (iree_hal_xrt_driver_t*)base_value;
}

static const iree_hal_xrt_driver_t* iree_hal_xrt_driver_const_cast(
const iree_hal_driver_t* base_value) {
IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_xrt_driver_vtable);
return (const iree_hal_xrt_driver_t*)base_value;
}

static iree_status_t iree_hal_xrt_device_check_params(
const iree_hal_xrt_device_params_t* params) {
if (params->arena_block_size < 4096) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"arena block size too small (< 4096 bytes)");
}
return iree_ok_status();
}

iree_status_t iree_hal_xrt_driver_create_internal(
iree_string_view_t identifier, iree_allocator_t host_allocator,
iree_hal_driver_t** out_driver) {
iree_string_view_t identifier,
const iree_hal_xrt_device_params_t* device_params,
iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
iree_hal_xrt_driver_t* driver = NULL;
iree_host_size_t total_size = iree_sizeof_struct(*driver) + identifier.size;
IREE_RETURN_IF_ERROR(
Expand All @@ -66,6 +84,7 @@ iree_status_t iree_hal_xrt_driver_create_internal(
iree_string_view_append_to_buffer(
identifier, &driver->identifier,
(char*)driver + iree_sizeof_struct(*driver));
driver->device_params = *device_params;

int device_count = xrt::system::enumerate_devices();
if (IREE_UNLIKELY(device_count == 0)) {
Expand All @@ -79,13 +98,16 @@ iree_status_t iree_hal_xrt_driver_create_internal(
}

IREE_API_EXPORT iree_status_t iree_hal_xrt_driver_create(
iree_string_view_t identifier, iree_allocator_t host_allocator,
iree_hal_driver_t** out_driver) {
iree_string_view_t identifier,
const iree_hal_xrt_device_params_t* device_params,
iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
IREE_ASSERT_ARGUMENT(out_driver);
IREE_TRACE_ZONE_BEGIN(z0);

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_xrt_device_check_params(device_params));
iree_status_t status = iree_hal_xrt_driver_create_internal(
identifier, host_allocator, out_driver);
identifier, device_params, host_allocator, out_driver);

IREE_TRACE_ZONE_END(z0);
return status;
Expand Down Expand Up @@ -192,8 +214,9 @@ static iree_status_t iree_hal_xrt_driver_create_device_by_id(
iree_hal_xrt_driver_t* driver = iree_hal_xrt_driver_cast(base_driver);
iree_string_view_t device_name = iree_make_cstring_view("xrt");

iree_status_t status = iree_hal_xrt_device_create(
base_driver, device_name, driver->device, host_allocator, out_device);
iree_status_t status =
iree_hal_xrt_device_create(device_name, &driver->device_params,
driver->device, host_allocator, out_device);

IREE_TRACE_ZONE_END(z0);
return status;
Expand All @@ -208,8 +231,9 @@ static iree_status_t iree_hal_xrt_driver_create_device_by_path(
iree_hal_xrt_driver_t* driver = iree_hal_xrt_driver_cast(base_driver);
iree_string_view_t device_name = iree_make_cstring_view("xrt");

iree_status_t status = iree_hal_xrt_device_create(
base_driver, device_name, driver->device, host_allocator, out_device);
iree_status_t status =
iree_hal_xrt_device_create(device_name, &driver->device_params,
driver->device, host_allocator, out_device);

IREE_TRACE_ZONE_END(z0);
return status;
Expand Down

0 comments on commit 168699a

Please sign in to comment.