Skip to content

Commit

Permalink
[onert-micro] Add cmsis-nn FullyConnected kernel (#11564)
Browse files Browse the repository at this point in the history
This commit adds the cmsis-nn FullyConnected kernel.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>

Co-authored-by: Artem Balyshev <[email protected]>
  • Loading branch information
BalyshevArtem and Artem Balyshev authored Sep 21, 2023
1 parent ae16dc0 commit 0eb8ad1
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ REGISTER_KERNEL(CONV_2D, Conv2D)
REGISTER_KERNEL(LOGISTIC, Logistic)
REGISTER_KERNEL(GATHER, Gather)
REGISTER_KERNEL(EXP, Exp)
REGISTER_KERNEL(FULLY_CONNECTED, FullyConnected)
REGISTER_KERNEL(GREATER, Greater)
REGISTER_KERNEL(GREATER_EQUAL, GreaterEqual)
REGISTER_KERNEL(EXPAND_DIMS, ExpandDims)
Expand Down
Empty file.
114 changes: 74 additions & 40 deletions onert-micro/luci-interpreter/pal/cmsisnn/PALFullyConnected.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,52 +14,26 @@
* limitations under the License.
*/

#ifndef LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
#define LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
#ifndef LUCI_INTERPRETER_PAL_FULLY_CONNECTED_H
#define LUCI_INTERPRETER_PAL_FULLY_CONNECTED_H

#include "PALFullyConnectedCommon.h"

#include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>
#include <arm_nnfunctions.h>

namespace luci_interpreter_pal
{
template <typename T>
static inline void FullyConnected(const tflite::FullyConnectedParams &params,
const tflite::RuntimeShape &input_shape, const T *input_data,
const tflite::RuntimeShape &filter_shape, const T *filter_data,
const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
const tflite::RuntimeShape &output_shape, T *output_data)
{
{
// MARK: At this moment this operation doesn't support
assert(false && "FullyConnected NYI");
(void)params;
(void)input_shape;
(void)input_data;
(void)filter_shape;
(void)filter_data;
(void)bias_shape;
(void)bias_data;
(void)output_shape;
(void)output_data;
}
}

template <>
inline void
FullyConnected<int8_t>(const tflite::FullyConnectedParams &params,
const tflite::RuntimeShape &input_shape, const int8_t *input_data,
const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
const tflite::RuntimeShape &output_shape, int8_t *output_data)
inline void FullyConnected<int8_t>(const luci_interpreter_pal::FullyConnectedParams &params,
const int32_t *, const int8_t *input_data,
const int32_t *filter_shape, const int8_t *filter_data,
const int32_t *bias_data, const int32_t *output_shape,
int8_t *output_data)
{
assert(output_shape.DimensionsCount() == 2);

const int batches = output_shape.Dims(0);
const int output_depth = output_shape.Dims(1);

const int filter_dim_count = filter_shape.DimensionsCount();
const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
const int batches = output_shape[0];
const int output_depth = output_shape[1];
const int accum_depth = filter_shape[1];

cmsis_nn_fc_params fc_params;
fc_params.input_offset = params.input_offset;
Expand Down Expand Up @@ -107,8 +81,68 @@ FullyConnected<int8_t>(const tflite::FullyConnectedParams &params,
auto res =
arm_fully_connected_s8(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
filter_data, &bias_dims, bias_data, &output_dims, output_data);
assert(res == ARM_MATH_SUCCESS);
assert(res == ARM_CMSIS_NN_SUCCESS);
}

template <>
inline void FullyConnected(const luci_interpreter_pal::FullyConnectedParams &params,
                           const int32_t *, const int16_t *input_data, const int32_t *filter_shape,
                           const int8_t *filter_data, const int64_t *bias_data,
                           const int32_t *output_shape, int16_t *output_data)
{
  // int16 activations x int8 weights FullyConnected, dispatched to CMSIS-NN's
  // arm_fully_connected_s16. The unnamed first parameter is the input shape,
  // which this kernel does not need: batches and depths are recovered from
  // output_shape ([batches, output_depth] — presumably, per the indexing below;
  // confirm against the caller) and filter_shape[1] (accumulation depth).
  const int batches = output_shape[0];
  const int output_depth = output_shape[1];
  const int accum_depth = filter_shape[1];

  // Quantization / activation parameters forwarded from the interpreter.
  cmsis_nn_fc_params fc_params;
  fc_params.input_offset = params.input_offset;
  fc_params.output_offset = params.output_offset;
  fc_params.filter_offset = params.weights_offset;
  fc_params.activation.min = params.quantized_activation_min;
  fc_params.activation.max = params.quantized_activation_max;

  cmsis_nn_per_tensor_quant_params quant_params;
  quant_params.multiplier = params.output_multiplier;
  quant_params.shift = params.output_shift;

  // CMSIS-NN treats FullyConnected as a 1x1 "image": n = batch, c = channels.
  cmsis_nn_dims input_dims;
  input_dims.n = batches;
  input_dims.h = 1;
  input_dims.w = 1;
  input_dims.c = accum_depth;

  cmsis_nn_dims filter_dims;
  filter_dims.n = accum_depth;
  filter_dims.h = 1;
  filter_dims.w = 1;
  filter_dims.c = output_depth;

  cmsis_nn_dims bias_dims;
  bias_dims.n = 1;
  bias_dims.h = 1;
  bias_dims.w = 1;
  bias_dims.c = output_depth;

  cmsis_nn_dims output_dims;
  output_dims.n = batches;
  output_dims.h = 1;
  output_dims.w = 1;
  output_dims.c = output_depth;

  // Scratch buffer sized by CMSIS-NN for this filter geometry. Note that
  // std::make_unique throws on allocation failure rather than returning null,
  // so the assert below is belt-and-braces only (kept for parity with the
  // s8 path).
  const int32_t buf_size = arm_fully_connected_s16_get_buffer_size(&filter_dims);
  auto buffer = std::make_unique<int8_t[]>(buf_size);
  assert(buffer != nullptr);

  cmsis_nn_context ctx;
  ctx.buf = buffer.get();
  ctx.size = buf_size;

  const auto res =
    arm_fully_connected_s16(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
                            filter_data, &bias_dims, bias_data, &output_dims, output_data);
  assert(res == ARM_CMSIS_NN_SUCCESS);
  (void)res; // silence unused-variable warning when asserts are compiled out (NDEBUG)
}

} // namespace luci_interpreter_pal

#endif // LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
#endif // LUCI_INTERPRETER_PAL_FULLY_CONNECTED_H
2 changes: 1 addition & 1 deletion onert-micro/luci-interpreter/pal/mcu/PALFullyConnected.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ inline void FullyConnected(const luci_interpreter_pal::FullyConnectedParams &, c
const int32_t *, int16_t *)
{
// MARK: At this moment this operation doesn't support
assert(false && "FullyConnected INT8 NYI");
assert(false && "FullyConnected INT16 NYI");
}

} // namespace luci_interpreter_pal
Expand Down
54 changes: 41 additions & 13 deletions onert-micro/luci-interpreter/src/kernels/FullyConnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ void evalFloat(const circle::Tensor *input, const circle::Tensor *weights,
#ifndef DIS_QUANT
void evalQuantized(const circle::Tensor *input, const circle::Tensor *weights,
const circle::Tensor *bias, const circle::Tensor *output,
const circle::FullyConnectedOptions *options, BaseRuntimeGraph *runtime_graph)
const circle::FullyConnectedOptions *options, BaseRuntimeGraph *runtime_graph,
DataType type)
{
double real_multiplier = 0.0;
int output_shift;
Expand All @@ -80,7 +81,9 @@ void evalQuantized(const circle::Tensor *input, const circle::Tensor *weights,
&output_activation_max);

int32_t input_offset = -Tensor::zero_point(input);
int32_t filter_offset = -Tensor::zero_point(weights);
int32_t filter_offset = 0;
if (type == DataType::U8)
filter_offset = -Tensor::zero_point(weights);
int32_t output_offset = Tensor::zero_point(output);

luci_interpreter_pal::FullyConnectedParams op_params{};
Expand Down Expand Up @@ -112,11 +115,31 @@ void evalQuantized(const circle::Tensor *input, const circle::Tensor *weights,

int32_t output_shape[kMaxSmallSize];
kernels::getTensorDims(output, runtime_graph, output_shape);

luci_interpreter_pal::FullyConnected(
op_params, input_shape, kernels::getTensorData<uint8_t>(input_data), weights_shape,
kernels::getTensorData<uint8_t>(weights_data), kernels::getTensorData<int32_t>(bias_data),
output_shape, kernels::getTensorData<uint8_t>(output_data));
if (type == DataType::S8)
{
luci_interpreter_pal::FullyConnected<int8_t>(
op_params, input_shape, kernels::getTensorData<int8_t>(input_data), weights_shape,
kernels::getTensorData<int8_t>(weights_data), kernels::getTensorData<int32_t>(bias_data),
output_shape, kernels::getTensorData<int8_t>(output_data));
}
else if (type == DataType::U8)
{
luci_interpreter_pal::FullyConnected<uint8_t>(
op_params, input_shape, kernels::getTensorData<uint8_t>(input_data), weights_shape,
kernels::getTensorData<uint8_t>(weights_data), kernels::getTensorData<int32_t>(bias_data),
output_shape, kernels::getTensorData<uint8_t>(output_data));
}
else if (type == DataType::S16)
{
luci_interpreter_pal::FullyConnected(
op_params, input_shape, kernels::getTensorData<int16_t>(input_data), weights_shape,
kernels::getTensorData<int8_t>(weights_data), kernels::getTensorData<int64_t>(bias_data),
output_shape, kernels::getTensorData<int16_t>(output_data));
}
else
{
assert(false && "Unsupported quantize type");
}
}
#endif

Expand Down Expand Up @@ -160,9 +183,12 @@ void configure_kernel_CircleFullyConnected(const circle::Operator *cur_op,
}
else if (Tensor::element_type(weights) == DataType::S8)
{
LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == DataType::S8);
LUCI_INTERPRETER_CHECK(Tensor::element_type(output) == DataType::S8);
LUCI_INTERPRETER_CHECK(!bias || Tensor::element_type(bias) == DataType::S32)
LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == DataType::S8 ||
Tensor::element_type(input) == DataType::S16);
LUCI_INTERPRETER_CHECK(Tensor::element_type(output) == DataType::S8 ||
Tensor::element_type(output) == DataType::S16);
LUCI_INTERPRETER_CHECK(!bias || Tensor::element_type(bias) == DataType::S32 ||
Tensor::element_type(bias) == DataType::S64)
}
#endif // DIS_QUANT
else
Expand Down Expand Up @@ -210,12 +236,14 @@ void execute_kernel_CircleFullyConnected(const circle::Operator *cur_op,
assert(output != nullptr);

const auto *options = cur_op->builtin_options_as_FullyConnectedOptions();

switch (Tensor::element_type(input))
const auto input_type = Tensor::element_type(input);
switch (input_type)
{
#ifndef DIS_QUANT
case DataType::U8:
evalQuantized(input, weights, bias, output, options, runtime_graph);
case DataType::S8:
case DataType::S16:
evalQuantized(input, weights, bias, output, options, runtime_graph, input_type);
break;
#endif // DIS_QUANT
#ifndef DIS_FLOAT
Expand Down

0 comments on commit 0eb8ad1

Please sign in to comment.