Skip to content

Commit

Permalink
[onert-micro] Quantized S8 for Sub (#13682)
Browse files Browse the repository at this point in the history
- Enable Sub S8 support

ONE-DCO-1.0-Signed-off-by: Chunseok Lee <[email protected]>
  • Loading branch information
chunseoklee authored Aug 19, 2024
1 parent f3d7383 commit d4b32bd
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 0 deletions.
30 changes: 30 additions & 0 deletions onert-micro/onert-micro/include/pal/common/PALSubCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,26 @@ namespace execute
{
namespace pal
{

// Quantized int8 subtraction of one element pair: computes the requantized
// value of (x - y) using the fixed-point parameters in `params`.
//
// x, y    - quantized input values.
// params  - offsets, multipliers, shifts and activation range to apply.
// returns - the result clamped to
//           [params.quantized_activation_min, params.quantized_activation_max]
//           and narrowed to int8_t.
//
// Declared `inline` because this is a non-template function defined in a
// header; without it, including the header from more than one translation unit
// produces multiple-definition link errors.
inline int8_t SubFunc(int8_t x, int8_t y, const core::ArithmeticQuantParams &params)
{
  // Apply the input offsets in 32-bit so the additions cannot wrap in int8.
  const int32_t input1_val = params.input1_offset + x;
  const int32_t input2_val = params.input2_offset + y;

  // Scale both inputs up by 2^left_shift before the per-input rescale.
  const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
  const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);

  // Rescale each input with its own multiplier/shift pair.
  const int32_t scaled_input1_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
    shifted_input1_val, params.input1_multiplier, params.input1_shift);
  const int32_t scaled_input2_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
    shifted_input2_val, params.input2_multiplier, params.input2_shift);

  // The actual subtraction happens on the rescaled values.
  const int32_t raw_diff = scaled_input1_val - scaled_input2_val;

  // Rescale with the output multiplier/shift and add the output offset.
  const int32_t raw_output = multiplyByQuantizedMultiplierSmallerThanOneExp(
                               raw_diff, params.output_multiplier, params.output_shift) +
                             params.output_offset;

  // Clamp to the activation range before narrowing to int8.
  const int32_t clamped_output = std::min(params.quantized_activation_max,
                                          std::max(params.quantized_activation_min, raw_output));
  return static_cast<int8_t>(clamped_output);
}

template <typename T>
OMStatus Sub(const core::BinaryArithmeticBroadcastParams &params, const int flat_size,
const T *input1_data, const T *input2_data, T *output_data)
Expand All @@ -44,6 +64,16 @@ OMStatus BroadcastSub4DSlow(const core::BinaryArithmeticBroadcastParams &params,
return Ok;
}

// Quantized int8 subtraction with broadcasting.
//
// Forwards to BroadcastBinaryFunction6DSlow (note: the 6D helper, despite the
// "4D" in this function's name) with SubFunc as the per-element operation.
//
// params        - quantization parameters handed to SubFunc for every element.
// input*_shape  - shapes of the two operands.
// input*_data   - quantized operand buffers.
// output_shape  - shape of the result.
// output_data   - buffer receiving the quantized result.
// returns       - always Ok; no status from the helper is inspected here.
//
// Declared `inline` because this non-template overload is defined in a header;
// without it, multiple including translation units would each emit a
// definition and fail to link.
inline OMStatus BroadcastSub4DSlow(const core::ArithmeticQuantParams &params,
                                   const core::OMRuntimeShape &input1_shape,
                                   const int8_t *input1_data,
                                   const core::OMRuntimeShape &input2_shape,
                                   const int8_t *input2_data,
                                   const core::OMRuntimeShape &output_shape, int8_t *output_data)
{
  BroadcastBinaryFunction6DSlow(params, input1_shape, input1_data, input2_shape, input2_data,
                                output_shape, output_data, SubFunc);
  return Ok;
}

} // namespace pal
} // namespace execute
} // namespace onert_micro
Expand Down
19 changes: 19 additions & 0 deletions onert-micro/onert-micro/include/pal/mcu/PALSub.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,24 @@
#define ONERT_MICRO_EXECUTE_PAL_SUB_H

#include "PALSubCommon.h"
#include "PALUtils.h"

namespace onert_micro
{
namespace execute
{
namespace pal
{

// Quantized int8 element-wise subtraction without broadcasting.
//
// Runs SubFunc over `flat_size` element pairs through the ElementWise helper.
//
// params       - quantization parameters handed to SubFunc for every element.
// flat_size    - number of elements to process.
// input1_data  - quantized minuend buffer.
// input2_data  - quantized subtrahend buffer.
// output_data  - buffer receiving the quantized result.
// returns      - always Ok.
//
// Declared `inline` because this non-template function is defined in a header;
// without it, including the header from several translation units causes
// multiple-definition link errors.
inline OMStatus Sub(const core::ArithmeticQuantParams &params, const uint32_t flat_size,
                    const int8_t *input1_data, const int8_t *input2_data, int8_t *output_data)
{
  ElementWise(flat_size, params, input1_data, input2_data, output_data, SubFunc);
  return Ok;
}

} // namespace pal
} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_PAL_SUB_H
116 changes: 116 additions & 0 deletions onert-micro/onert-micro/include/test_models/sub/S8SubKernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_TEST_MODELS_SUB_KERNEL_S8_H
#define ONERT_MICRO_TEST_MODELS_SUB_KERNEL_S8_H

#include "TestDataSubBase.h"

namespace onert_micro
{
namespace test_model
{
namespace sub_s8_no_broadcasting
{

/*
* Sub Kernel:
*
* Input_1(2, 2) Input_2(2, 2)
* \ /
* Sub(w/o broadcast)
* |
* Output(2, 2)
*/

// Binary model blob used by the S8 Sub test (presumably a serialized Circle
// flatbuffer - the bytes at offset 4 spell "CIR0"). Readable strings embedded
// in the data include "main", the tensor names "ifm1", "ifm2" and "ofm", and
// the trailing tag "ONE-tflite2circle".
// Do not edit by hand: the blob is consumed as-is by the test runner.
unsigned char test_kernel_model_circle[] = {
0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x30, 0x00, 0x00, 0x00, 0x44, 0x02, 0x00, 0x00, 0x60, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x1c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x88, 0xff, 0xff, 0xff, 0x8c, 0xff, 0xff, 0xff, 0x90, 0xff, 0xff, 0xff, 0x94, 0xff, 0xff, 0xff,
0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00,
0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
0x1c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, 0x6c, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0c, 0x00,
0x07, 0x00, 0x08, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x1c, 0x10, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
0x78, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x1a, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00,
0x4c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x48, 0x00, 0x00, 0x00,
0x0c, 0xff, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0xfe, 0x42, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xc3, 0x03, 0x00, 0x00, 0x00,
0x6f, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x8a, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00,
0x4c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x4c, 0x00, 0x00, 0x00,
0x7c, 0xff, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0xfe, 0x42, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xc3, 0x04, 0x00, 0x00, 0x00,
0x69, 0x66, 0x6d, 0x32, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00,
0x18, 0x00, 0x14, 0x00, 0x13, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00,
0x20, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09,
0x54, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfe, 0x42,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xc3, 0x04, 0x00, 0x00, 0x00, 0x69, 0x66, 0x6d, 0x31,
0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00,
0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x29, 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d,
0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00};
// Quantized values fed to the first Sub operand (four elements).
std::vector<int8_t> input1_data = {
5,
7,
13,
-3,
};
// Quantized values fed to the second Sub operand (four elements).
std::vector<int8_t> input2_data = {-5, -11, 5, 5};
// Expected kernel output. For these values each element equals
// input1_data[i] - input2_data[i] (e.g. 5 - (-5) = 10, -3 - 5 = -8).
std::vector<int8_t> reference_output_data = {10, 18, 8, -8};
} // namespace sub_s8_no_broadcasting

// Test data provider for the quantized (int8) Sub kernel.
//
// Only the non-broadcast scenario has data; asking for the broadcast variant
// prints a diagnostic to stderr and leaves the data members untouched by this
// constructor.
class TestDataS8Sub : public TestDataSubBase<int8_t>
{
public:
  explicit TestDataS8Sub(bool is_with_broadcast) : TestDataSubBase<int8_t>(is_with_broadcast)
  {
    // Guard: there is no S8 broadcast model yet, so report and bail out.
    if (is_with_broadcast)
    {
      std::cerr << "Sub S8 with broadcasting not supported yet!\n";
      return;
    }

    // Wire up the model blob and its matching input/reference vectors.
    namespace data = sub_s8_no_broadcasting;
    _test_kernel_model_circle = data::test_kernel_model_circle;
    _input1_data = data::input1_data;
    _input2_data = data::input2_data;
    _reference_output_data = data::reference_output_data;
  }

  ~TestDataS8Sub() override = default;
};

} // namespace test_model
} // namespace onert_micro

#endif // ONERT_MICRO_TEST_MODELS_SUB_KERNEL_S8_H
69 changes: 69 additions & 0 deletions onert-micro/onert-micro/src/execute/kernels/Sub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,50 @@ constexpr uint32_t input1TensorIdx = 0;
constexpr uint32_t input2TensorIdx = 1;
constexpr uint32_t outputTensorIdx = 0;

// Derives the fixed-point parameters used by the quantized Sub kernel.
//
// params  - output structure; offsets, left shift, the three multiplier/shift
//           pairs and the quantized activation range are written to it.
// input1  - first operand tensor (source of zero point and scale).
// input2  - second operand tensor (source of zero point and scale).
// output  - result tensor (source of zero point, scale and element type).
// act     - fused activation used to derive the clamping range.
void calculateQuantParams(core::ArithmeticQuantParams &params, const circle::Tensor *input1,
                          const circle::Tensor *input2, const circle::Tensor *output,
                          circle::ActivationFunctionType act)
{
  // Fetch zero point / scale for each of the three tensors.
  long in1_zero_point;
  float in1_scale;
  readQuantParams(input1, in1_zero_point, in1_scale);

  long in2_zero_point;
  float in2_scale;
  readQuantParams(input2, in2_zero_point, in2_scale);

  long out_zero_point;
  float out_scale;
  readQuantParams(output, out_zero_point, out_scale);

  // Input offsets are the negated zero points; the output offset is the zero
  // point itself.
  params.input1_offset = -static_cast<int32_t>(in1_zero_point);
  params.input2_offset = -static_cast<int32_t>(in2_zero_point);
  params.output_offset = static_cast<int32_t>(out_zero_point);

  // INT16 outputs use a left shift of 15; every other type uses 20.
  const bool is_int16_output = (output->type() == circle::TensorType_INT16);
  params.left_shift = is_int16_output ? 15 : 20;

  // Both inputs are normalised against twice the larger input scale.
  const double twice_max_input_scale =
    2 * static_cast<double>(std::max(in1_scale, in2_scale));
  const double in1_real_multiplier = static_cast<double>(in1_scale) / twice_max_input_scale;
  const double in2_real_multiplier = static_cast<double>(in2_scale) / twice_max_input_scale;

  // The output multiplier undoes the left shift and maps to the output scale.
  const int left_shift_factor = 1 << params.left_shift;
  const double out_real_multiplier =
    twice_max_input_scale / (left_shift_factor * static_cast<double>(out_scale));

  // Convert the real-valued multipliers into integer multiplier + shift pairs.
  quantizeMultiplierSmallerThanOneExp(in1_real_multiplier, &params.input1_multiplier,
                                      &params.input1_shift);
  quantizeMultiplierSmallerThanOneExp(in2_real_multiplier, &params.input2_multiplier,
                                      &params.input2_shift);
  quantizeMultiplierSmallerThanOneExp(out_real_multiplier, &params.output_multiplier,
                                      &params.output_shift);

  // Clamping bounds implied by the fused activation, in the quantized domain.
  calculateActivationRangeQuantized(act, out_zero_point, out_scale, output->type(),
                                    &params.quantized_activation_min,
                                    &params.quantized_activation_max);
}

} // namespace

// NOTE: doesn't currently support dynamic shapes
Expand Down Expand Up @@ -160,6 +204,31 @@ OMStatus onert_micro::execute::execute_kernel_CircleSub(const OMExecuteArgs &exe
}
break;
#endif // DIS_FLOAT
#ifndef DIS_QUANT
case circle::TensorType_INT8:
{
core::ArithmeticQuantParams sub_params{};

calculateQuantParams(sub_params, input1, input2, output,
options->fused_activation_function());

if (need_broadcast)
{
status = pal::BroadcastSub4DSlow(
sub_params, input1_shape, core::utils::castInputData<int8_t>(input1_data), input2_shape,
core::utils::castInputData<int8_t>(input2_data), output_shape,
core::utils::castOutputData<int8_t>(output_data));
}
else
{
status = pal::Sub(sub_params, input1_shape.flatSize(),
core::utils::castInputData<int8_t>(input1_data),
core::utils::castInputData<int8_t>(input2_data),
core::utils::castOutputData<int8_t>(output_data));
}
}
break;
#endif // DIS_QUANT
default:
{
status = UnsupportedType;
Expand Down
13 changes: 13 additions & 0 deletions onert-micro/onert-micro/src/execute/kernels/tests/Sub.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "test_models/sub/FloatSubKernel.h"
#include "test_models/sub/NegSubKernel.h"
#include "test_models/sub/IntSubKernel.h"
#include "test_models/sub/S8SubKernel.h"

namespace onert_micro
{
Expand Down Expand Up @@ -53,6 +54,18 @@ TEST_F(SubTest, INT_P)
}
}

// Positive test: int8 quantized Sub on same-shaped inputs must reproduce the
// reference output stored with the test model.
TEST_F(SubTest, S8_P)
{
  // Same-shape operands, i.e. the non-broadcast path.
  {
    const bool with_broadcast = false;
    test_model::TestDataS8Sub sub_s8_case(with_broadcast);

    // The first argument (2) is passed through to checkKernel unchanged.
    std::vector<int8_t> actual_output =
      onert_micro::execute::testing::checkKernel<int8_t>(2, &sub_s8_case);

    EXPECT_THAT(actual_output, sub_s8_case.get_output_data_by_index(0));
  }
}

TEST_F(SubTest, Float_P)
{
// No broadcast
Expand Down

0 comments on commit d4b32bd

Please sign in to comment.