Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[onert-micro] support weight quantized (int8) FullyConnected kernel #14137

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onert-micro/onert-micro/include/core/OMKernelData.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ struct FullyConnectedParams
int32_t weights_offset;
int32_t output_offset;
int32_t output_multiplier;
const float *weights_scales;
bool is_channel_wise_quant;
int output_shift;
// uint8_t, etc, activation params.
int32_t quantized_activation_min;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,11 @@ OMStatus FullyConnected(const core::FullyConnectedParams &params, const InputTyp
return Ok;
}

template <>
OMStatus inline FullyConnected<float>(const core::FullyConnectedParams &params,
const float *input_data,
const core::OMRuntimeShape &filter_shape,
const float *filter_data, const float *bias_data,
const core::OMRuntimeShape &output_shape, float *output_data)
template <typename WeightType>
OMStatus inline FullyConnected(const core::FullyConnectedParams &params, const float *input_data,
const core::OMRuntimeShape &filter_shape,
const WeightType *filter_data, const float *bias_data,
const core::OMRuntimeShape &output_shape, float *output_data)
{
const float output_activation_min = params.float_activation_min;
const float output_activation_max = params.float_activation_max;
Expand All @@ -93,12 +92,24 @@ OMStatus inline FullyConnected<float>(const core::FullyConnectedParams &params,

for (int b = 0; b < batches; ++b)
{
const float *weight_scale_ptr = params.weights_scales;
for (int out_c = 0; out_c < output_depth; ++out_c)
{
float total = 0.f;
for (int d = 0; d < accum_depth; ++d)
{
total += input_data[b * accum_depth + d] * filter_data[out_c * accum_depth + d];
auto input_value = input_data[b * accum_depth + d];
if (std::is_same<WeightType, float>::value)
{
total += input_value * filter_data[out_c * accum_depth + d];
}
else
{
const float filter_scale = *weight_scale_ptr;
const float filter_value =
static_cast<float>(filter_data[out_c * accum_depth + d]) * filter_scale;
total += input_value * filter_value;
}
}
float bias_value = 0.0f;
if (bias_data)
Expand All @@ -107,6 +118,12 @@ OMStatus inline FullyConnected<float>(const core::FullyConnectedParams &params,
}
output_data[out_c + output_depth * b] =
std::min(std::max(total + bias_value, output_activation_min), output_activation_max);

if (std::is_same<WeightType, int8_t>::value)
{
if (params.is_channel_wise_quant)
weight_scale_ptr++;
}
}
}
return Ok;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,82 @@ const std::vector<float> reference_output_data = {263.84323, 260.84323, 259.8432

} // namespace fully_connected_float

namespace fully_connected_float_weights_quantized_int8
{

/*
* FullyConnected Kernel:
* Input - float32
* Weight - int8
* Bias - float32
* Out - float32
*
* Input(1, 4) Weight(4, 4) Bias(4)
* \ | /
* \ | /
* FullyConnected
* |
* Output(1, 4)
*/

// Serialized circle flatbuffer model for the hybrid FullyConnected test case
// (float32 activations, int8-quantized weights, float32 bias and output).
// NOTE(review): this is a generated binary blob — do not edit the bytes by
// hand; regenerate it from the source model if the test graph must change.
const unsigned char test_kernel_model_circle[] = {
0x20, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00,
0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00,
0x12, 0x00, 0x00, 0x00, 0xd8, 0x00, 0x00, 0x00, 0xf4, 0x00, 0x00, 0x00, 0xc0, 0x00, 0x00, 0x00,
0x08, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0xe0, 0x02, 0x00, 0x00,
0xc8, 0x02, 0x00, 0x00, 0x24, 0x02, 0x00, 0x00, 0xc0, 0x01, 0x00, 0x00, 0x90, 0x01, 0x00, 0x00,
0x74, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0xcc, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x5f, 0x6f, 0x70, 0x5f, 0x74, 0x61, 0x62, 0x6c, 0x65,
0x00, 0x00, 0x00, 0x00, 0x22, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00,
0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x4f, 0x4e, 0x45, 0x5f, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x74, 0x61, 0x62, 0x6c, 0x65,
0x00, 0x00, 0x00, 0x00, 0x62, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6f, 0x75, 0x74, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00,
0x6e, 0x6e, 0x70, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x09, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00,
0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00,
0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x7c, 0x01, 0x00, 0x00, 0xc4, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00,
0x44, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00,
0x07, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x18, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x90, 0xfe, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0xe0, 0xfe, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x75, 0x74, 0x00, 0xc4, 0xfe, 0xff, 0xff,
0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0xff, 0xff, 0xff,
0x34, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x62, 0x69, 0x61, 0x73, 0x00, 0x00, 0x00, 0x00, 0xa6, 0xff, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0xc0, 0x00, 0x00, 0x40, 0xc0,
0x00, 0x00, 0x80, 0x40, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00,
0x18, 0x00, 0x08, 0x00, 0x07, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x09, 0x94, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
0x40, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x00, 0x00,
0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x55, 0x7f, 0x00, 0x7f, 0x00, 0x00, 0x00, 0x00, 0x7f, 0x00,
0x00, 0x00, 0x00, 0x7f, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x06, 0x83, 0xc1, 0x3c, 0x04, 0x02, 0x01, 0x3d,
0x85, 0x42, 0x21, 0x3d, 0x06, 0x83, 0x41, 0x3d, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x69, 0x6e, 0x00, 0x00, 0xf0, 0xff, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00};

// Float32 input activations for the (1, 4) input tensor.
const std::vector<float> input_data = {17.491695, 15.660671, 4.7347794, -15.796822};

// Expected float32 output for the (1, 4) output tensor. NOTE(review): these
// values presumably equal input x dequantized(int8 weight) + bias as produced
// by the reference kernel — confirm against the generator that built the blob.
const std::vector<float> reference_output_data = {-19.529659, 60.642685, 20.673897, -90.780930};

} // namespace fully_connected_float_weights_quantized_int8

class TestDataFloatFullyConnected : public TestDataFullyConnectedBase<float>
{
public:
Expand All @@ -109,6 +185,20 @@ class TestDataFloatFullyConnected : public TestDataFullyConnectedBase<float>
~TestDataFloatFullyConnected() override = default;
};

// Test-data provider for the hybrid FullyConnected kernel variant:
// float32 activations combined with int8-quantized weights, checked
// against a float32 reference output.
class TestDataFloatWQInt8FullyConnected : public TestDataFullyConnectedBase<float>
{
public:
  TestDataFloatWQInt8FullyConnected()
  {
    // Shorten the long data-namespace name for the three assignments below.
    namespace src = fully_connected_float_weights_quantized_int8;
    _test_kernel_model_circle = src::test_kernel_model_circle;
    _reference_output_data = src::reference_output_data;
    _input_data = src::input_data;
  }

  ~TestDataFloatWQInt8FullyConnected() override = default;
};

} // namespace test_model
} // namespace onert_micro

Expand Down
34 changes: 29 additions & 5 deletions onert-micro/onert-micro/src/execute/kernels/FullyConnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,35 @@ onert_micro::execute::execute_kernel_CircleFullyConnected(const OMExecuteArgs &e
if (status != Ok)
return status;

status =
pal::FullyConnected(params, core::utils::castInputData<float>(input_data),
OMRuntimeShape(weight), core::utils::castInputData<float>(weight_data),
core::utils::castInputData<float>(bias_data), OMRuntimeShape(output),
core::utils::castOutputData<float>(output_data));
switch (weight->type())
{
case circle::TensorType_FLOAT32:
{

status = pal::FullyConnected(
params, core::utils::castInputData<float>(input_data), OMRuntimeShape(weight),
core::utils::castInputData<float>(weight_data),
core::utils::castInputData<float>(bias_data), OMRuntimeShape(output),
core::utils::castOutputData<float>(output_data));
}
break;
case circle::TensorType_INT8:
{
// weight quantized INT8 mode
params.weights_scales =
reinterpret_cast<const float *>(weight->quantization()->scale()->data());
params.is_channel_wise_quant = weight->quantization()->scale()->size() > 1;

status = pal::FullyConnected(
params, core::utils::castInputData<float>(input_data), OMRuntimeShape(weight),
core::utils::castInputData<int8_t>(weight_data),
core::utils::castInputData<float>(bias_data), OMRuntimeShape(output),
core::utils::castOutputData<float>(output_data));
}
break;
default:
assert(false && "Unsupported hybrid weight type");
}
}
break;
#endif // DIS_FLOAT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ TEST_F(FullyConnectedTest, Float_P)
EXPECT_THAT(output_data_vector, test_data_kernel.get_output_data_by_index(0));
}

// Hybrid FullyConnected kernel: float32 input activations, int8 weights.
TEST_F(FullyConnectedTest, FloatWQInt8_P)
{
  using namespace onert_micro;
  test_model::TestDataFloatWQInt8FullyConnected kernel_data;
  const auto actual_output = execute::testing::checkKernel<float>(1, &kernel_data);
  EXPECT_THAT(actual_output, kernel_data.get_output_data_by_index(0));
}

TEST_F(FullyConnectedTest, S8_P)
{
onert_micro::test_model::TestDataS8FullyConnected test_data_kernel;
Expand Down
64 changes: 58 additions & 6 deletions onert-micro/onert-micro/src/import/kernels/FullyConnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ constexpr uint32_t outputTensorIdx = 0;
OMStatus
onert_micro::import::configure_kernel_CircleFullyConnected(const OMConfigureArgs &config_args)
{

OMRuntimeContext &runtime_context = config_args.runtime_context;
uint16_t op_index = config_args.kernel_index;
OMRuntimeStorage &runtime_storage = config_args.runtime_storage;
Expand All @@ -50,7 +51,6 @@ onert_micro::import::configure_kernel_CircleFullyConnected(const OMConfigureArgs
const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
const circle::Tensor *weight = runtime_kernel.inputs[weightTensorIdx];
const circle::Tensor *bias = runtime_kernel.inputs[biasTensorIdx];

const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];

assert(input != nullptr);
Expand All @@ -60,13 +60,65 @@ onert_micro::import::configure_kernel_CircleFullyConnected(const OMConfigureArgs

OMStatus status = Ok;

if ((input->type() == circle::TensorType_FLOAT32 &&
weight->type() != circle::TensorType_FLOAT32) or
(input->type() == circle::TensorType_INT8 && weight->type() != circle::TensorType_INT8) or
(input->type() == circle::TensorType_INT16 && weight->type() != circle::TensorType_INT16))
#ifndef DIS_FLOAT
if (weight->type() == circle::TensorType_FLOAT32)
{
return UnsupportedType;

status = utils::checkCondition(input->type() == circle::TensorType_FLOAT32 and
output->type() == circle::TensorType_FLOAT32 and
(!bias or bias->type() == circle::TensorType_FLOAT32));
if (status != Ok)
return status;
}
#endif // DIS_FLOAT
#ifndef DIS_QUANT
if (weight->type() == circle::TensorType_UINT8)
{

status = utils::checkCondition(input->type() == circle::TensorType_UINT8 and
output->type() == circle::TensorType_UINT8 and
(!bias or bias->type() == circle::TensorType_INT32));
if (status != Ok)
return status;
}
else if (weight->type() == circle::TensorType_INT8)
stamalakhov marked this conversation as resolved.
Show resolved Hide resolved
{
status = utils::checkCondition(input->type() == circle::TensorType_INT8 or
input->type() == circle::TensorType_FLOAT32);
if (status != Ok)
return status;

status = utils::checkCondition(output->type() == circle::TensorType_INT8 or
output->type() == circle::TensorType_FLOAT32);
if (status != Ok)
return status;

status = utils::checkCondition(!bias or bias->type() == circle::TensorType_INT32 or
bias->type() == circle::TensorType_INT64 or
bias->type() == circle::TensorType_FLOAT32);
if (status != Ok)
return status;

if (input->type() == circle::TensorType_FLOAT32)
{
// hybrid mode
// Check it is channel wise quantization
status = utils::checkCondition(weight->quantization() != nullptr and
weight->quantization()->scale() != nullptr);
if (status != Ok)
return status;
}
}
else if (weight->type() == circle::TensorType_INT16)
{

status = utils::checkCondition(input->type() == circle::TensorType_INT16 and
output->type() == circle::TensorType_INT16 and
(!bias or bias->type() == circle::TensorType_INT32));
if (status != Ok)
return status;
}
#endif // DIS_QUANT

core::OMRuntimeShape weight_shape(weight);
core::OMRuntimeShape bias_shape(bias);
Expand Down