Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mixed and input parameters to MIGraphX EP #38

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions onnxruntime/contrib_ops/rocm/fused_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,7 @@ class FusedConv : public onnxruntime::rocm::Conv<T, false> {
auto ret = miopenCompileFusionPlan(handle, fusion->plan);
if (miopenStatusSuccess == ret) {
fusion->compiled_on.insert(handle);
}
else {
} else {
return ret;
}
return miopenStatusSuccess;
Expand Down
80 changes: 64 additions & 16 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1115,39 +1115,51 @@
}

std::vector<std::string> input_names, output_names;
no_input_shape = get_input_output_names(graph_body_viewer, input_names, output_names);
no_input_shape = no_input_shape || get_input_output_names(graph_body_viewer, input_names, output_names);

// by parsing the model_proto, create a program corresponding to
// the input fused_node
migraphx::program prog;

if (!no_input_shape) {
LOGS_DEFAULT(VERBOSE) << "No Input shapes detected quantizing model" << std::endl;
prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options);
if (fp16_enable_) {
migraphx::quantize_fp16(prog);
}

// Read in the calibration data and map it to a MIGraphX parameter map for the calibration ops
if (int8_enable_ && int8_calibration_cache_available_) {
LOGS_DEFAULT(INFO) << "Quantizing input program to int8" << std::endl;
migraphx::quantize_int8_options quant_opts;
migraphx::program_parameters quant_params;

auto param_shapes = prog.get_parameter_shapes();

for (auto&& name : param_shapes.names()) {
auto dynamic_range_i = dynamic_range_map.find(name);
if (dynamic_range_i != dynamic_range_map.end()) {
quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second)));
}
// Add all calibration data read in from int8 table
for (auto& [cal_key, cal_val] : dynamic_range_map) {
auto cal_val_shape = migraphx::shape(migraphx_shape_float_type);
quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast<void*>(std::move(&cal_val))));
}

quant_opts.add_calibration_data(quant_params);

// specify the ops we want to int8 quantize
quant_opts.add_op_name("convolution");
quant_opts.add_op_name("dot");
groenenboomj marked this conversation as resolved.
Show resolved Hide resolved

// perform static quantization on the programs
migraphx::quantize_int8(prog, t_, quant_opts);
LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete" << std::endl;
}

if (fp16_enable_) {
LOGS_DEFAULT(INFO) << "Quantizing input program to fp16" << std::endl;
migraphx::quantize_fp16(prog);
LOGS_DEFAULT(INFO) << "Quantizing input program to fp16: Complete" << std::endl;
}

migraphx::compile_options co;
co.set_fast_math(false);
LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl;
prog.compile(t_, co);
LOGS_DEFAULT(INFO) << "Model Compile: Complete" << std::endl;
auto prog_output_shapes = prog.get_output_shapes();
for (std::size_t i = 0; i < output_names.size(); ++i) {
auto out_len = prog_output_shapes[i].lengths();
Expand Down Expand Up @@ -1197,6 +1209,7 @@
bool input_shape_match = true;
migraphx::program_parameter_shapes param_shapes;
if (no_input_shape) {
LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again" << std::endl;
for (auto& it : map_input_name_index) {
auto& name = it.first;
auto& index = it.second;
Expand All @@ -1208,6 +1221,7 @@
input_shape_match = false;
}
} else {
LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model" << std::endl;
param_shapes = prog.get_parameter_shapes();
auto prog_output_shapes = prog.get_output_shapes();

Expand Down Expand Up @@ -1241,33 +1255,64 @@
// input shapes are different, needs to re-parse onnx and
// re-compile the program
if (!input_shape_match) {
LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling" << std::endl;
prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options);
if (fp16_enable) {
migraphx::quantize_fp16(prog);
}

// Read in the calibration data and map it to a MIGraphX parameter map for the calibration ops
if (int8_enable && int8_calibration_cache_available) {
LOGS_DEFAULT(INFO) << "Quantize Int8: Begin" << std::endl;
migraphx::quantize_int8_options quant_opts;
migraphx::program_parameters quant_params;

auto param_shapes = prog.get_parameter_shapes();

// Add input parameter data and the values they're set to
for (auto&& name : param_shapes.names()) {
auto dynamic_range_i = map_dynamic_range.find(name);
if (dynamic_range_i != map_dynamic_range.end()) {
quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second)));
if (map_input_name_index.count(name) > 0) {
auto input_tensor = ctx.GetInput(map_input_name_index[name]);
auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
const auto tensor_shape = tensor_info.GetShape();
const auto tensor_type = tensor_info.GetElementType();

migraphx_shape_datatype_t mgx_type;
getMIGraphXType(tensor_type, mgx_type);
auto mgx_s = param_shapes[name];

if (mgx_type != mgx_s.type()) {
LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch";
}
quant_params.add(name, migraphx::argument(param_shapes[name], const_cast<void*>(input_tensor.GetTensorRawData())));

Check warning on line 1284 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1284: Lines should be <= 120 characters long [whitespace/line_length] [2]
}
}

// Add all calibration data read in from int8 table
for (auto& [cal_key, cal_val] : map_dynamic_range) {
auto cal_val_shape = migraphx::shape(migraphx_shape_float_type);
quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast<void*>(std::move(&cal_val))));

Check warning on line 1291 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1291: Lines should be <= 120 characters long [whitespace/line_length] [2]

Check warning on line 1291 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1291: Add #include <utility> for move [build/include_what_you_use] [4]
}
quant_opts.add_calibration_data(quant_params);

// specify the ops we want to int8 quantize
quant_opts.add_op_name("convolution");
quant_opts.add_op_name("dot");

// perform static quantization on the programs
migraphx::quantize_int8(prog, t, quant_opts);
LOGS_DEFAULT(INFO) << "Quantize Int8: Completed" << std::endl;
}

if (fp16_enable) {
LOGS_DEFAULT(INFO) << "Quantize fp16: Begin" << std::endl;
migraphx::quantize_fp16(prog);
LOGS_DEFAULT(INFO) << "Quantize fp16: Completed" << std::endl;
}

LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl;
migraphx::compile_options co;
co.set_fast_math(false);
prog.compile(t, co);

LOGS_DEFAULT(INFO) << "Model Compile: Completed" << std::endl;
mgx_state->prog = prog;
param_shapes = prog.get_parameter_shapes();
no_input_shape = false;
Expand All @@ -1279,6 +1324,7 @@
if (param_shapes.size() > 0) {
for (auto&& name : param_shapes.names()) {
if (map_input_name_index.count(name) > 0) {
LOGS_DEFAULT(INFO) << "Setting parameters for:" << name << std::endl;
auto input_tensor = ctx.GetInput(map_input_name_index[name]);
auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
const auto tensor_shape = tensor_info.GetShape();
Expand All @@ -1291,6 +1337,8 @@
if (mgx_type != mgx_s.type()) {
LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch";
}

LOGS_DEFAULT(INFO) << "Writing Raw tensor data " << std::endl;
m.add(name, migraphx::argument(param_shapes[name],
const_cast<void*>(input_tensor.GetTensorRawData())));
}
Expand Down
Loading