Skip to content

Commit

Permalink
Adding ortvalue features support for MGX EP (#81)
Browse files Browse the repository at this point in the history
* Adding ortvalue support for MGX EP

---------

authored-by: Uros Petkovic <[email protected]>
  • Loading branch information
urpetkov-amd authored and Ted Themistokleous committed Jan 17, 2025
1 parent 2ab51c7 commit 577be22
Show file tree
Hide file tree
Showing 14 changed files with 354 additions and 28 deletions.
15 changes: 15 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,21 @@ typedef struct OrtMIGraphXProviderOptions {
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true
const char* migraphx_load_model_path; // migraphx model path name
bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false

/** \brief MIGraphX memory limit (To use all possible memory pass in maximum size_t)
* Defaults to SIZE_MAX.
* \note If a ::OrtArenaCfg has been applied, it will override this field
*/
size_t migraphx_mem_limit;

/** \brief Strategy used to grow the memory arena
* 0 = kNextPowerOfTwo<br>
* 1 = kSameAsRequested<br>
* Defaults to 0.
* \note If a ::OrtArenaCfg has been applied, it will override this field
*/
int migraphx_arena_extend_strategy;

} OrtMIGraphXProviderOptions;

/** \brief OpenVINO Provider Options
Expand Down
45 changes: 42 additions & 3 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
#include "core/common/safeint.h"
#include "core/common/logging/severity.h"
#include "migraphx_execution_provider.h"
#include "migraphx_execution_provider_info.h"
#include "migraphx_execution_provider_utils.h"
#include "migraphx_allocator.h"
#include "gpu_data_transfer.h"
#include "migraphx_inc.h"
#include <hip/hip_version.h>
#include "migraphx_call.h"

#include "migraphx_stream_handle.h"

Expand Down Expand Up @@ -211,12 +212,50 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
MIGraphXExecutionProvider::~MIGraphXExecutionProvider() {
}

// Creates the device allocator used by the MIGraphX execution provider.
//
// When `external_allocator_info` carries both an alloc and a free callback
// (see UseExternalAllocator()), the result wraps those callbacks in a
// MIGraphXExternalAllocator. Otherwise a MIGraphXAllocator is created; its
// arena configuration is `*default_memory_arena_cfg` when that pointer is
// non-null, and is otherwise built from `migx_mem_limit` and
// `arena_extend_strategy`.
AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id,
                                                                size_t migx_mem_limit,
                                                                ArenaExtendStrategy arena_extend_strategy,
                                                                MIGraphXExecutionProviderExternalAllocatorInfo
                                                                    external_allocator_info,
                                                                const OrtArenaCfg* default_memory_arena_cfg) {
  // User-supplied callbacks take precedence over the built-in allocator.
  if (external_allocator_info.UseExternalAllocator()) {
    // The callbacks are captured by copy so the factory does not depend on the
    // lifetime of this function's parameter.
    auto make_external = [external_allocator_info](OrtDevice::DeviceId id) {
      return std::make_unique<MIGraphXExternalAllocator>(id, HIP,
                                                         external_allocator_info.alloc,
                                                         external_allocator_info.free,
                                                         external_allocator_info.empty_cache);
    };
    AllocatorCreationInfo external_creation_info(make_external, device_id, false);
    return CreateAllocator(external_creation_info);
  }

  auto make_builtin = [](OrtDevice::DeviceId id) {
    return std::make_unique<MIGraphXAllocator>(id, HIP);
  };

  // An explicitly supplied arena configuration overrides the individual
  // memory-limit / extend-strategy settings.
  const OrtArenaCfg arena_cfg =
      default_memory_arena_cfg
          ? *default_memory_arena_cfg
          : OrtArenaCfg(migx_mem_limit, static_cast<int>(arena_extend_strategy),
                        -1, -1, -1, -1L);

  // ROCM malloc/free is expensive, so the built-in allocator always sits behind an arena.
  AllocatorCreationInfo builtin_creation_info(
      make_builtin,
      device_id,
      true,
      arena_cfg,
      true,    // stream aware
      false);  // cross-stream sharing disabled
  return CreateAllocator(builtin_creation_info);
}

std::vector<AllocatorPtr> MIGraphXExecutionProvider::CreatePreferredAllocators() {
AllocatorCreationInfo default_memory_info(
[](OrtDevice::DeviceId device_id) { return CreateMIGraphXAllocator(device_id, onnxruntime::CUDA); }, info_.device_id);
[](OrtDevice::DeviceId device_id) { return std::make_unique<MIGraphXAllocator>(device_id, onnxruntime::CUDA); }, info_.device_id);
AllocatorCreationInfo pinned_allocator_info(
[](OrtDevice::DeviceId device_id) {
return CreateMIGraphXPinnedAllocator(device_id, onnxruntime::CUDA_PINNED);
return std::make_unique<HIPPinnedAllocator>(device_id, onnxruntime::CUDA_PINNED);
},
0);
return std::vector<AllocatorPtr>{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "core/framework/execution_provider.h"
#include <mutex>
#include "core/providers/migraphx/migraphx_execution_provider_info.h"
#include "core/providers/migraphx/migraphx_inc.h"
#include "core/providers/migraphx/migraphx_call.h"

#include <map>
#include <unordered_map>
Expand Down Expand Up @@ -76,6 +76,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override;

static AllocatorPtr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy,
MIGraphXExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg);

std::unique_ptr<IndexedSubGraph> GetSubGraph(const std::vector<std::size_t>& graph_nodes_index, const GraphViewer& graph) const;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/migraphx/migraphx_execution_provider_info.h"

#include "core/common/make_string.h"
Expand All @@ -10,6 +11,12 @@
#include "migraphx_call.h"

namespace onnxruntime {

// Name table for ArenaExtendStrategy: pairs each enumerator with the string
// spelling used for the arena-extend-strategy provider option. It is used both
// when parsing that option and when converting it back to a string.
const EnumNameMapping<ArenaExtendStrategy> arena_extend_strategy_mapping{
{ArenaExtendStrategy::kNextPowerOfTwo, "kNextPowerOfTwo"},
{ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"},
};

namespace migraphx {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
Expand All @@ -22,12 +29,20 @@ constexpr const char* kSaveModelPath = "migx_save_model_name";
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
constexpr const char* kLoadModelPath = "migx_load_model_name";
constexpr const char* kExhaustiveTune = "migx_exhaustive_tune";
constexpr const char* kMemLimit = "migx_mem_limit";
constexpr const char* kArenaExtendStrategy = "migx_arena_extend_strategy";
constexpr const char* kGpuExternalAlloc = "migx_external_alloc";
constexpr const char* kGpuExternalFree = "migx_external_free";
constexpr const char* kGpuExternalEmptyCache = "migx_external_empty_cache";

} // namespace provider_option_names
} // namespace migraphx

MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
MIGraphXExecutionProviderInfo info{};
void* alloc = nullptr;
void* free = nullptr;
void* empty_cache = nullptr;
ORT_THROW_IF_ERROR(
ProviderOptionsParser{}
.AddValueParser(
Expand All @@ -42,13 +57,42 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
return Status::OK();
})
.AddValueParser(
migraphx::provider_option_names::kGpuExternalAlloc,
[&alloc](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
alloc = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
migraphx::provider_option_names::kGpuExternalFree,
[&free](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
free = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
migraphx::provider_option_names::kGpuExternalEmptyCache,
[&empty_cache](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
empty_cache = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune)
.AddAssignmentToReference(migraphx::provider_option_names::kMemLimit, info.mem_limit)
.AddAssignmentToEnumReference(migraphx::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy)
.Parse(options));

MIGraphXExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache};
info.external_allocator_info = alloc_info;

return info;
}

Expand All @@ -59,6 +103,12 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
{migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.mem_limit)},
{migraphx::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
{migraphx::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
{migraphx::provider_option_names::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.empty_cache))},
{migraphx::provider_option_names::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)},
};
return options;
Expand All @@ -71,6 +121,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
{migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.migraphx_mem_limit)},
{migraphx::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast<onnxruntime::ArenaExtendStrategy>(info.migraphx_arena_extend_strategy))},
{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)},
};
return options;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,36 @@
#include <string>

#include "core/framework/ortdevice.h"
#include "core/common/hash_combine.h"
#include "core/framework/arena_extend_strategy.h"
#include "core/framework/provider_options.h"
#include "core/session/onnxruntime_c_api.h"

namespace onnxruntime {

// Optional user-supplied device memory callbacks for the MIGraphX execution
// provider. The callbacks are stored as untyped pointers; a null pointer means
// "not provided".
struct MIGraphXExecutionProviderExternalAllocatorInfo {
  void* alloc{nullptr};        // allocation callback
  void* free{nullptr};         // deallocation callback
  void* empty_cache{nullptr};  // cache-release callback; not required by UseExternalAllocator()

  // Leaves every callback null (the in-class initializers apply).
  MIGraphXExecutionProviderExternalAllocatorInfo() = default;

  // Stores the given alloc (a), free (f) and empty-cache (e) callbacks.
  MIGraphXExecutionProviderExternalAllocatorInfo(void* a, void* f, void* e)
      : alloc{a}, free{f}, empty_cache{e} {}

  // True only when both the alloc and the free callback are set.
  bool UseExternalAllocator() const {
    return alloc != nullptr && free != nullptr;
  }
};

// Information needed to construct MIGraphX execution providers.
struct MIGraphXExecutionProviderInfo {
std::string target_device;
Expand All @@ -25,8 +51,42 @@ struct MIGraphXExecutionProviderInfo {
std::string load_model_file{"./compiled_model.mxr"};
bool exhaustive_tune{false};

size_t mem_limit{std::numeric_limits<size_t>::max()}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified)
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified)

OrtArenaCfg* default_memory_arena_cfg{nullptr};
MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info{};

static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);
static ProviderOptions ToProviderOptions(const OrtMIGraphXProviderOptions& info);
};
} // namespace onnxruntime

// std::hash specialization for MIGraphXExecutionProviderInfo. The hash covers
// the device id, arena extend strategy, the boolean options, the memory limit
// and the external allocator callback addresses.
template <>
struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> {
  size_t operator()(const ::onnxruntime::MIGraphXExecutionProviderInfo& info) const {
    // Fold the small scalar options into a single word.
    // Layout: device_id (16 bits), arena_extend_strategy (2 bits reserved),
    // then one bit per boolean option.
    size_t packed = static_cast<size_t>(info.device_id);
    packed ^= static_cast<size_t>(info.arena_extend_strategy) << 16;
    packed ^= static_cast<size_t>(info.fp16_enable) << 18;
    packed ^= static_cast<size_t>(info.int8_enable) << 19;
    packed ^= static_cast<size_t>(info.int8_use_native_calibration_table) << 20;
    packed ^= static_cast<size_t>(info.save_compiled_model) << 21;
    packed ^= static_cast<size_t>(info.load_compiled_model) << 22;
    packed ^= static_cast<size_t>(info.exhaustive_tune) << 23;

    size_t digest{0xbc9f1d34};  // seed
    onnxruntime::HashCombine(packed, digest);
    onnxruntime::HashCombine(info.mem_limit, digest);

    // External allocator callbacks contribute by address.
    const auto& callbacks = info.external_allocator_info;
    onnxruntime::HashCombine(reinterpret_cast<size_t>(callbacks.alloc), digest);
    onnxruntime::HashCombine(reinterpret_cast<size_t>(callbacks.free), digest);
    onnxruntime::HashCombine(reinterpret_cast<size_t>(callbacks.empty_cache), digest);

    // default_memory_arena_cfg is intentionally left out of the hash for now.
    return digest;
  }
};
26 changes: 26 additions & 0 deletions onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/migraphx/migraphx_provider_factory.h"
#include "migraphx_execution_provider.h"
#include "migraphx_execution_provider_info.h"
#include "migraphx_provider_factory_creator.h"
#include "migraphx_allocator.h"
#include "gpu_data_transfer.h"
Expand Down Expand Up @@ -42,6 +43,27 @@ struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX {
return std::make_unique<HIPPinnedAllocator>(device_id, name);
}

// Copies `count` bytes from host memory `src` to device memory `dst`, then
// synchronizes stream 0 before returning.
// Both HIP calls are wrapped in HIP_CALL_THROW; presumably a failing call throws — confirm against migraphx_call.h.
void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) override {
// hipMemcpy() operates on the default stream
HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyHostToDevice));

// To ensure that the copy has completed, invoke a stream sync for the default stream.
// For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated.
// The function will return once the pageable buffer has been copied to the staging memory for DMA transfer
// to device memory, but the DMA to final destination may not have completed.

HIP_CALL_THROW(hipStreamSynchronize(0));
}

// Copies `count` bytes from device memory `src` to host memory `dst`.
// No explicit stream synchronization follows the copy (see the note inside).
// Used by onnxruntime_pybind_state.cc
void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override {
// For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed.
HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost));
}

// Provider-info entry point for creating a MIGraphX device allocator.
// Forwards every argument unchanged to the static
// MIGraphXExecutionProvider::CreateMIGraphXAllocator and returns its result.
// `external_allocator_info` is received by reference here but the static
// function takes it by value, so it is copied at the call.
std::shared_ptr<IAllocator> CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override {
return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, migx_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg);
}
} g_info;

struct MIGraphX_Provider : Provider {
Expand Down Expand Up @@ -77,6 +99,8 @@ struct MIGraphX_Provider : Provider {
if (options.migraphx_load_model_path != nullptr) {
info.load_model_file = options.migraphx_load_model_path;
}
info.arena_extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(options.migraphx_arena_extend_strategy);
info.mem_limit = options.migraphx_mem_limit;
return std::make_shared<MIGraphXProviderFactory>(info);
}

Expand Down Expand Up @@ -109,6 +133,8 @@ struct MIGraphX_Provider : Provider {
migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str();
migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model;
migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str();
migx_options.migraphx_arena_extend_strategy = static_cast<int>(internal_options.arena_extend_strategy);
migx_options.migraphx_mem_limit = internal_options.mem_limit;
}

ProviderOptions GetProviderOptions(const void* provider_options) override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ struct MIGraphXExecutionProviderExternalAllocatorInfo;
struct ProviderInfo_MIGraphX {
virtual std::unique_ptr<onnxruntime::IAllocator> CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0;
virtual std::unique_ptr<onnxruntime::IAllocator> CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0;
virtual void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) = 0;
virtual void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) = 0;
virtual std::shared_ptr<onnxruntime::IAllocator> CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0;

protected:
~ProviderInfo_MIGraphX() = default; // Can only be destroyed through a subclass instance
Expand Down
Loading

0 comments on commit 577be22

Please sign in to comment.