Skip to content

Commit

Permalink
[onert-micro] Reduce code duplication
Browse files Browse the repository at this point in the history
This PR reduces code duplication in the pooling execute path and in the Mul and Add ops.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
Artem Balyshev committed Jun 24, 2024
1 parent 22bce71 commit 6120d7c
Show file tree
Hide file tree
Showing 10 changed files with 286 additions and 370 deletions.
1 change: 1 addition & 0 deletions onert-micro/onert-micro/include/execute/OMUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace onert_micro
namespace execute
{

void readQuantParams(const circle::Tensor *tensor, long &zero_point, float &scale);
template <typename T>
OMStatus calculateActivationRange(circle::ActivationFunctionType activation, T *activation_min,
T *activation_max)
Expand Down
49 changes: 49 additions & 0 deletions onert-micro/onert-micro/include/execute/kernels/PoolingCommon.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ONERT_MICRO_EXECUTE_KERNELS_POOLING_COMMON_H
#define ONERT_MICRO_EXECUTE_KERNELS_POOLING_COMMON_H

#include "OMStatus.h"

#include "core/OMUtils.h"
#include "core/OMKernelData.h"

#include "execute/OMKernelExecutionBuilder.h"
#include "execute/OMUtils.h"
#include "execute/OMRuntimeKernel.h"
#include <functional>

namespace onert_micro
{
namespace execute
{

// Shared execute routine for 2D pooling kernels (e.g. AveragePool2D), used to
// avoid duplicating the tensor/params handling in every pooling kernel.
// It dispatches to the type-specific callback supplied by the caller:
//   f_float - invoked for FLOAT32 input/output tensors
//   f_int8  - invoked for INT8 (quantized) input/output tensors
// Each callback receives the assembled core::Pool2DParams plus the input and
// output shapes/buffers, and returns an OMStatus which is propagated to the
// caller. Returns a non-Ok status on failure (e.g. unsupported tensor type).
// NOTE(review): dynamic shapes are presumably not supported here, matching the
// individual pooling kernels — confirm against the implementation.
OMStatus execute_pooling_common(
  const OMExecuteArgs &execute_args,
  const std::function<OMStatus(const core::Pool2DParams &params,
                               const core::OMRuntimeShape &input_shape, const float *input_data,
                               const core::OMRuntimeShape &output_shape, float *output_data)>
    &f_float,
  const std::function<OMStatus(const core::Pool2DParams &params,
                               const core::OMRuntimeShape &input_shape, const int8_t *input_data,
                               const core::OMRuntimeShape &output_shape, int8_t *output_data)>
    &f_int8);

} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_KERNELS_POOLING_COMMON_H
1 change: 1 addition & 0 deletions onert-micro/onert-micro/src/execute/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ set(SOURCES
OMRuntimeKernel.cpp
OMUtils.cpp
kernels/ConvolutionCommon.cpp
kernels/PoolingCommon.cpp
)

# Add configure kernels
Expand Down
16 changes: 16 additions & 0 deletions onert-micro/onert-micro/src/execute/OMUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,19 @@ OMStatus onert_micro::execute::calculateActivationRangeQuantized(
return calculateActivationRangeQuantizedImpl(activation, qmin, qmax, output_zero_point,
output_scale, activation_min, activation_max);
}

// Reads the per-tensor quantization parameters (single zero point and single
// scale) out of a circle::Tensor into the caller-provided out-parameters.
// Callers must guarantee the tensor carries exactly one scale/zero-point pair;
// this is only checked via assert (compiled out in release builds).
void onert_micro::execute::readQuantParams(const circle::Tensor *tensor, long &zero_point,
                                           float &scale)
{
  const auto *quantization = tensor->quantization();

  // additional check
  assert(quantization != nullptr);                                                  // Fix caller
  assert(quantization->scale() != nullptr and quantization->scale()->size() == 1);  // Fix caller
  assert(quantization->zero_point() != nullptr and
         quantization->zero_point()->size() == 1); // Fix caller

  // read zero point
  zero_point = (*quantization->zero_point())[0];
  // read scale
  scale = (*quantization->scale())[0];
}
40 changes: 14 additions & 26 deletions onert-micro/onert-micro/src/execute/kernels/Add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,20 @@ void calculateQuantParams(core::ArithmeticQuantParams &params, const circle::Ten
const circle::Tensor *input2, const circle::Tensor *output,
circle::ActivationFunctionType act)
{
assert(input1->quantization() != nullptr); // Fix caller
assert(input2->quantization() != nullptr); // Fix caller
assert(output->quantization() != nullptr); // Fix caller

assert(input1->quantization()->scale() != nullptr and
input1->quantization()->scale()->size() == 1); // Fix caller
assert(input2->quantization()->scale() != nullptr and
input2->quantization()->scale()->size() == 1); // Fix caller
assert(output->quantization()->scale() != nullptr and
output->quantization()->scale()->size() == 1); // Fix caller

assert(input1->quantization()->zero_point() != nullptr and
input1->quantization()->zero_point()->size() == 1); // Fix caller
assert(input2->quantization()->zero_point() != nullptr and
input2->quantization()->zero_point()->size() == 1); // Fix caller
assert(output->quantization()->zero_point() != nullptr and
output->quantization()->zero_point()->size() == 1); // Fix caller

// 8bit -> 8bit general quantized path, with general rescalings
const auto input1_zp = input1->quantization()->zero_point()->operator[](0);
const auto input2_zp = input2->quantization()->zero_point()->operator[](0);
const auto output_zp = output->quantization()->zero_point()->operator[](0);

const auto input1_scale = input1->quantization()->scale()->operator[](0);
const auto input2_scale = input2->quantization()->scale()->operator[](0);
const auto output_scale = output->quantization()->scale()->operator[](0);
long input1_zp;
long input2_zp;
long output_zp;

float input1_scale;
float input2_scale;
float output_scale;

// Read input1 quant params
readQuantParams(input1, input1_zp, input1_scale);
// Read input2 quant params
readQuantParams(input2, input2_zp, input2_scale);
// Read output quant params
readQuantParams(output, output_zp, output_scale);

params.input1_offset = -static_cast<int32_t>(input1_zp);
params.input2_offset = -static_cast<int32_t>(input2_zp);
Expand Down
129 changes: 17 additions & 112 deletions onert-micro/onert-micro/src/execute/kernels/AveragePool2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,128 +14,33 @@
* limitations under the License.
*/

#include "OMStatus.h"

#include "core/OMUtils.h"
#include "core/OMKernelData.h"

#include "execute/OMKernelExecutionBuilder.h"
#include "execute/OMUtils.h"
#include "execute/OMRuntimeKernel.h"

#include "execute/kernels/PoolingCommon.h"
#include "PALAveragePool2D.h"

using namespace onert_micro;
using namespace onert_micro::execute;

namespace
{

constexpr uint32_t inputTensorIdx = 0;
constexpr uint32_t outputTensorIdx = 0;

} // namespace

// NOTE: doesn't currently support dynamic shapes
OMStatus onert_micro::execute::execute_kernel_CircleAveragePool2D(const OMExecuteArgs &execute_args)
{
core::OMRuntimeContext &runtime_context = execute_args.runtime_context;
core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage;
uint16_t op_index = execute_args.kernel_index;

const circle::Tensor *input = nullptr;
const circle::Tensor *output = nullptr;

uint8_t *input_data = nullptr;
uint8_t *output_data = nullptr;

OMStatus status = Ok;

const circle::Pool2DOptions *options = nullptr;
{
OMRuntimeKernel runtime_kernel;
runtime_kernel.readKernel(op_index, runtime_context);

input = runtime_kernel.inputs[inputTensorIdx];
output = runtime_kernel.outputs[outputTensorIdx];
auto avg_pool_float_lambda = [](const core::Pool2DParams &params,
const core::OMRuntimeShape &input_shape, const float *input_data,
const core::OMRuntimeShape &output_shape, float *output_data) {
return pal::AveragePool(params, input_shape, input_data, output_shape, output_data);
};

assert(input != nullptr);
assert(output != nullptr);

status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);
if (status != Ok)
return status;

input_data = runtime_kernel.inputs_data[inputTensorIdx];
output_data = runtime_kernel.outputs_data[outputTensorIdx];

options = runtime_kernel.first_operator->builtin_options_as_Pool2DOptions();
}

assert(input_data != nullptr);
assert(output_data != nullptr);
assert(options != nullptr);

core::OMRuntimeShape input_shape(input);

int32_t padding_h = 0;
int32_t padding_w = 0;

const int input_width = input_shape.dims(2);
const int input_height = input_shape.dims(1);
execute::computePaddingHeightWidth(
options->stride_h(), options->stride_w(), 1 /* dilation_rate_height */,
1 /* dilation_rate_width */, input_height, input_width, options->filter_height(),
options->filter_width(), options->padding(), &padding_h, &padding_w);

core::Pool2DParams params{};
params.pad_h = padding_h;
params.pad_w = padding_w;
params.stride_h = options->stride_h();
params.stride_w = options->stride_w();
params.filter_h = options->filter_height();
params.filter_w = options->filter_width();

switch (input->type())
{
#ifndef DIS_FLOAT
case circle::TensorType_FLOAT32:
{
calculateActivationRange(options->fused_activation_function(), &params.activation_min,
&params.activation_max);
status = pal::AveragePool(params, input_shape, core::utils::castInputData<float>(input_data),
core::OMRuntimeShape(output),
core::utils::castOutputData<float>(output_data));
}
break;
#endif // DIS_FLOAT
#ifndef DIS_QUANT
case circle::TensorType_INT8:
{
assert(output->quantization() != nullptr);
assert(output->quantization()->scale() != nullptr);
assert(output->quantization()->scale()->size() == 1);
const auto output_scale = output->quantization()->scale()->operator[](0);

assert(output->quantization()->zero_point() != nullptr);
assert(output->quantization()->zero_point()->size() == 1);
const auto output_zp = output->quantization()->zero_point()->operator[](0);

calculateActivationRangeQuantized(
options->fused_activation_function(), output_zp, output_scale, output->type(),
&params.quantized_activation_min, &params.quantized_activation_max);
status = pal::AveragePool(params, input_shape, core::utils::castInputData<int8_t>(input_data),
core::OMRuntimeShape(output),
core::utils::castOutputData<int8_t>(output_data));
}
break;
auto avg_pool_int8_lambda = [](const core::Pool2DParams &params,
const core::OMRuntimeShape &input_shape, const int8_t *input_data,
const core::OMRuntimeShape &output_shape, int8_t *output_data) {
return pal::AveragePool(params, input_shape, input_data, output_shape, output_data);
};
#else
auto avg_pool_int8_lambda = [](const core::Pool2DParams &params,
const core::OMRuntimeShape &input_shape, const int8_t *input_data,
const core::OMRuntimeShape &output_shape,
int8_t *output_data) { return UnsupportedType; };
#endif // DIS_QUANT
default:
{
status = UnsupportedType;
assert(false && "Unsupported type.");
}
}

return status;
return execute_pooling_common(execute_args, avg_pool_float_lambda, avg_pool_int8_lambda);
}
Loading

0 comments on commit 6120d7c

Please sign in to comment.