diff --git a/onert-micro/onert-micro/include/core/OMKernelData.h b/onert-micro/onert-micro/include/core/OMKernelData.h
index d0ab251777e..64bf96f0521 100644
--- a/onert-micro/onert-micro/include/core/OMKernelData.h
+++ b/onert-micro/onert-micro/include/core/OMKernelData.h
@@ -186,6 +186,8 @@ struct FullyConnectedParams
   int32_t weights_offset;
   int32_t output_offset;
   int32_t output_multiplier;
+  const float *weights_scales;
+  bool is_channel_wise_quant;
   int output_shift;
   // uint8_t, etc, activation params.
   int32_t quantized_activation_min;
diff --git a/onert-micro/onert-micro/include/pal/common/PALFullyConnectedCommon.h b/onert-micro/onert-micro/include/pal/common/PALFullyConnectedCommon.h
index e0cd74cf8f5..69908232510 100644
--- a/onert-micro/onert-micro/include/pal/common/PALFullyConnectedCommon.h
+++ b/onert-micro/onert-micro/include/pal/common/PALFullyConnectedCommon.h
@@ -76,12 +76,11 @@ OMStatus FullyConnected(const core::FullyConnectedParams &params, const InputTyp
   return Ok;
 }
 
-template <>
-OMStatus inline FullyConnected(const core::FullyConnectedParams &params,
-                               const float *input_data,
-                               const core::OMRuntimeShape &filter_shape,
-                               const float *filter_data, const float *bias_data,
-                               const core::OMRuntimeShape &output_shape, float *output_data)
+template <typename WeightType>
+OMStatus inline FullyConnected(const core::FullyConnectedParams &params, const float *input_data,
+                               const core::OMRuntimeShape &filter_shape,
+                               const WeightType *filter_data, const float *bias_data,
+                               const core::OMRuntimeShape &output_shape, float *output_data)
 {
   const float output_activation_min = params.float_activation_min;
   const float output_activation_max = params.float_activation_max;
@@ -93,12 +92,24 @@ OMStatus inline FullyConnected(const core::FullyConnectedParams &params,
 
   for (int b = 0; b < batches; ++b)
   {
+    const float *weight_scale_ptr = params.weights_scales;
     for (int out_c = 0; out_c < output_depth; ++out_c)
     {
       float total = 0.f;
       for (int d = 0; d < accum_depth; ++d)
       {
-        total += input_data[b * accum_depth + d] * filter_data[out_c * accum_depth + d];
+        auto input_value = input_data[b * accum_depth + d];
+        if (std::is_same<WeightType, float>::value)
+        {
+          total += input_value * filter_data[out_c * accum_depth + d];
+        }
+        else
+        {
+          const float filter_scale = *weight_scale_ptr;
+          const float filter_value =
+            static_cast<float>(filter_data[out_c * accum_depth + d]) * filter_scale;
+          total += input_value * filter_value;
+        }
       }
       float bias_value = 0.0f;
       if (bias_data)
@@ -107,6 +118,12 @@ OMStatus inline FullyConnected(const core::FullyConnectedParams &params,
       }
       output_data[out_c + output_depth * b] =
        std::min(std::max(total + bias_value, output_activation_min), output_activation_max);
+
+      if (std::is_same<WeightType, int8_t>::value)
+      {
+        if (params.is_channel_wise_quant)
+          weight_scale_ptr++;
+      }
     }
   }
   return Ok;
diff --git a/onert-micro/onert-micro/include/test_models/fully_connected/FloatFullyConnectedKernel.h b/onert-micro/onert-micro/include/test_models/fully_connected/FloatFullyConnectedKernel.h
index 00442fa939a..6f5f62db939 100644
--- a/onert-micro/onert-micro/include/test_models/fully_connected/FloatFullyConnectedKernel.h
+++ b/onert-micro/onert-micro/include/test_models/fully_connected/FloatFullyConnectedKernel.h
@@ -96,6 +96,82 @@ const std::vector<float> reference_output_data = {263.84323, 260.84323, 259.8432
 
 } // namespace fully_connected_float
 
+namespace fully_connected_float_weights_quantized_int8
+{
+
+/*
+ * FullyConnected Kernel:
+ * Input - float32
+ * Weight - int8
+ * Bias - float32
+ * Out - float32
+ *
+ * Input(1, 4)   Weight(4, 4)   Bias(4)
+ *       \            |            /
+ *        \           |           /
+ *          FullyConnected
+ *                |
+ *           Output(1, 4)
+ */
+
+const unsigned char test_kernel_model_circle[] = {
+  0x20, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00,
+  0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00,
+  0x12, 0x00, 0x00, 0x00, 0xd8, 0x00, 0x00, 0x00, 0xf4, 0x00, 0x00, 0x00, 0xc0, 0x00, 0x00, 0x00,
+  0x08, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0xe0, 0x02, 0x00, 0x00,
+  0xc8, 0x02, 0x00, 0x00, 0x24, 0x02, 0x00, 0x00, 0xc0, 0x01, 0x00, 0x00, 0x90, 0x01, 0x00, 0x00,
+  0x74, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00,
+  0x04, 0x00, 0x00, 0x00, 0xcc, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00,
+  0x0c, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x5f, 0x6f, 0x70, 0x5f, 0x74, 0x61, 0x62, 0x6c, 0x65,
+  0x00, 0x00, 0x00, 0x00, 0x22, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00,
+  0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
+  0x4f, 0x4e, 0x45, 0x5f, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x74, 0x61, 0x62, 0x6c, 0x65,
+  0x00, 0x00, 0x00, 0x00, 0x62, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6f, 0x75, 0x74, 0x00,
+  0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00,
+  0x6e, 0x6e, 0x70, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00,
+  0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x09, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00,
+  0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00,
+  0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+  0x04, 0x00, 0x00, 0x00, 0x7c, 0x01, 0x00, 0x00, 0xc4, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00,
+  0x44, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00,
+  0x07, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x18, 0x00, 0x00, 0x00,
+  0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x90, 0xfe, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00,
+  0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x02, 0x00, 0x00, 0x00, 0xe0, 0xfe, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x75, 0x74, 0x00, 0xc4, 0xfe, 0xff, 0xff,
+  0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0xff, 0xff, 0xff,
+  0x34, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x62, 0x69, 0x61, 0x73, 0x00, 0x00, 0x00, 0x00, 0xa6, 0xff, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00,
+  0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0xc0, 0x00, 0x00, 0x40, 0xc0,
+  0x00, 0x00, 0x80, 0x40, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00,
+  0x18, 0x00, 0x08, 0x00, 0x07, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00,
+  0x00, 0x00, 0x00, 0x09, 0x94, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
+  0x40, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x00, 0x00,
+  0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x10, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x55, 0x7f, 0x00, 0x7f, 0x00, 0x00, 0x00, 0x00, 0x7f, 0x00,
+  0x00, 0x00, 0x00, 0x7f, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00,
+  0x0c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+  0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+  0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x06, 0x83, 0xc1, 0x3c, 0x04, 0x02, 0x01, 0x3d,
+  0x85, 0x42, 0x21, 0x3d, 0x06, 0x83, 0x41, 0x3d, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00,
+  0x0c, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x02, 0x00, 0x00, 0x00, 0x69, 0x6e, 0x00, 0x00, 0xf0, 0xff, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00};
+
+const std::vector<float> input_data = {17.491695, 15.660671, 4.7347794, -15.796822};
+
+const std::vector<float> reference_output_data = {-19.529659, 60.642685, 20.673897, -90.780930};
+
+} // namespace fully_connected_float_weights_quantized_int8
+
 class TestDataFloatFullyConnected : public TestDataFullyConnectedBase<float>
 {
 public:
@@ -109,6 +185,20 @@ class TestDataFloatFullyConnected : public TestDataFullyConnectedBase<float>
   ~TestDataFloatFullyConnected() override = default;
 };
 
+class TestDataFloatWQInt8FullyConnected : public TestDataFullyConnectedBase<float>
+{
+public:
+  TestDataFloatWQInt8FullyConnected()
+  {
+    _input_data = fully_connected_float_weights_quantized_int8::input_data;
+    _reference_output_data = fully_connected_float_weights_quantized_int8::reference_output_data;
+    _test_kernel_model_circle =
+      fully_connected_float_weights_quantized_int8::test_kernel_model_circle;
+  }
+
+  ~TestDataFloatWQInt8FullyConnected() override = default;
+};
+
 } // namespace test_model
 } // namespace onert_micro
diff --git a/onert-micro/onert-micro/src/execute/kernels/FullyConnected.cpp b/onert-micro/onert-micro/src/execute/kernels/FullyConnected.cpp
index 981e93df324..89d3482a3fc 100644
--- a/onert-micro/onert-micro/src/execute/kernels/FullyConnected.cpp
+++ b/onert-micro/onert-micro/src/execute/kernels/FullyConnected.cpp
@@ -147,11 +147,35 @@ onert_micro::execute::execute_kernel_CircleFullyConnected(const OMExecuteArgs &e
       if (status != Ok)
         return status;
 
-      status =
-        pal::FullyConnected(params, core::utils::castInputData<float>(input_data),
-                            OMRuntimeShape(weight), core::utils::castInputData<float>(weight_data),
-                            core::utils::castInputData<float>(bias_data), OMRuntimeShape(output),
-                            core::utils::castOutputData<float>(output_data));
+      switch (weight->type())
+      {
+        case circle::TensorType_FLOAT32:
+        {
+
+          status = pal::FullyConnected(
+            params, core::utils::castInputData<float>(input_data), OMRuntimeShape(weight),
+            core::utils::castInputData<float>(weight_data),
+            core::utils::castInputData<float>(bias_data), OMRuntimeShape(output),
+            core::utils::castOutputData<float>(output_data));
+        }
+        break;
+        case circle::TensorType_INT8:
+        {
+          // weight quantized INT8 mode
+          params.weights_scales =
+            reinterpret_cast<const float *>(weight->quantization()->scale()->data());
+          params.is_channel_wise_quant = weight->quantization()->scale()->size() > 1;
+
+          status = pal::FullyConnected(
+            params, core::utils::castInputData<float>(input_data), OMRuntimeShape(weight),
+            core::utils::castInputData<int8_t>(weight_data),
+            core::utils::castInputData<float>(bias_data), OMRuntimeShape(output),
+            core::utils::castOutputData<float>(output_data));
+        }
+        break;
+        default:
+          assert(false && "Unsupported hybrid weight type");
+      }
     }
     break;
 #endif // DIS_FLOAT
diff --git a/onert-micro/onert-micro/src/execute/kernels/tests/FullyConnected.test.cpp b/onert-micro/onert-micro/src/execute/kernels/tests/FullyConnected.test.cpp
index 5085341b761..a61e9cda715 100644
--- a/onert-micro/onert-micro/src/execute/kernels/tests/FullyConnected.test.cpp
+++ b/onert-micro/onert-micro/src/execute/kernels/tests/FullyConnected.test.cpp
@@ -41,6 +41,15 @@ TEST_F(FullyConnectedTest, Float_P)
   EXPECT_THAT(output_data_vector, test_data_kernel.get_output_data_by_index(0));
 }
 
+// test hybrid kernel input:float32 + weight:int8
+TEST_F(FullyConnectedTest, FloatWQInt8_P)
+{
+  onert_micro::test_model::TestDataFloatWQInt8FullyConnected test_data_kernel;
+  std::vector<float> output_data_vector =
+    onert_micro::execute::testing::checkKernel<float>(1, &test_data_kernel);
+  EXPECT_THAT(output_data_vector, test_data_kernel.get_output_data_by_index(0));
+}
+
 TEST_F(FullyConnectedTest, S8_P)
 {
   onert_micro::test_model::TestDataS8FullyConnected test_data_kernel;
diff --git a/onert-micro/onert-micro/src/import/kernels/FullyConnected.cpp b/onert-micro/onert-micro/src/import/kernels/FullyConnected.cpp
index e7bd5a4b71a..f9e401e9dbd 100644
--- a/onert-micro/onert-micro/src/import/kernels/FullyConnected.cpp
+++ b/onert-micro/onert-micro/src/import/kernels/FullyConnected.cpp
@@ -40,6 +40,7 @@ constexpr uint32_t outputTensorIdx = 0;
 
 OMStatus
 onert_micro::import::configure_kernel_CircleFullyConnected(const OMConfigureArgs &config_args)
 {
+  OMRuntimeContext &runtime_context = config_args.runtime_context;
   uint16_t op_index = config_args.kernel_index;
   OMRuntimeStorage &runtime_storage = config_args.runtime_storage;
@@ -50,7 +51,6 @@ onert_micro::import::configure_kernel_CircleFullyConnected(const OMConfigureArgs
   const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
   const circle::Tensor *weight = runtime_kernel.inputs[weightTensorIdx];
   const circle::Tensor *bias = runtime_kernel.inputs[biasTensorIdx];
-
   const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];
 
   assert(input != nullptr);
@@ -60,13 +60,65 @@ onert_micro::import::configure_kernel_CircleFullyConnected(const OMConfigureArgs
 
   OMStatus status = Ok;
 
-  if ((input->type() == circle::TensorType_FLOAT32 &&
-       weight->type() != circle::TensorType_FLOAT32) or
-      (input->type() == circle::TensorType_INT8 && weight->type() != circle::TensorType_INT8) or
-      (input->type() == circle::TensorType_INT16 && weight->type() != circle::TensorType_INT16))
+#ifndef DIS_FLOAT
+  if (weight->type() == circle::TensorType_FLOAT32)
   {
-    return UnsupportedType;
+
+    status = utils::checkCondition(input->type() == circle::TensorType_FLOAT32 and
+                                   output->type() == circle::TensorType_FLOAT32 and
+                                   (!bias or bias->type() == circle::TensorType_FLOAT32));
+    if (status != Ok)
+      return status;
+  }
+#endif // DIS_FLOAT
+#ifndef DIS_QUANT
+  if (weight->type() == circle::TensorType_UINT8)
+  {
+
+    status = utils::checkCondition(input->type() == circle::TensorType_UINT8 and
+                                   output->type() == circle::TensorType_UINT8 and
+                                   (!bias or bias->type() == circle::TensorType_INT32));
+    if (status != Ok)
+      return status;
   }
+  else if (weight->type() == circle::TensorType_INT8)
+  {
+    status = utils::checkCondition(input->type() == circle::TensorType_INT8 or
+                                   input->type() == circle::TensorType_FLOAT32);
+    if (status != Ok)
+      return status;
+
+    status = utils::checkCondition(output->type() == circle::TensorType_INT8 or
+                                   output->type() == circle::TensorType_FLOAT32);
+    if (status != Ok)
+      return status;
+
+    status = utils::checkCondition(!bias or bias->type() == circle::TensorType_INT32 or
+                                   bias->type() == circle::TensorType_INT64 or
+                                   bias->type() == circle::TensorType_FLOAT32);
+    if (status != Ok)
+      return status;
+
+    if (input->type() == circle::TensorType_FLOAT32)
+    {
+      // hybrid mode
+      // Check it is channel wise quantization
+      status = utils::checkCondition(weight->quantization() != nullptr and
+                                     weight->quantization()->scale() != nullptr);
+      if (status != Ok)
+        return status;
+    }
+  }
+  else if (weight->type() == circle::TensorType_INT16)
+  {
+
+    status = utils::checkCondition(input->type() == circle::TensorType_INT16 and
+                                   output->type() == circle::TensorType_INT16 and
+                                   (!bias or bias->type() == circle::TensorType_INT32));
+    if (status != Ok)
+      return status;
+  }
+#endif // DIS_QUANT
 
   core::OMRuntimeShape weight_shape(weight);
   core::OMRuntimeShape bias_shape(bias);