From 0140d19c683a64b4b5562103dab3a6c152e24a0f Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Wed, 26 Jun 2024 14:30:37 +0300 Subject: [PATCH] add leaky relu --- .../src/execute/kernels/LeakyRelu.cpp | 83 +------------------ .../src/execute/kernels/ReluCommon.cpp | 8 +- 2 files changed, 10 insertions(+), 81 deletions(-) diff --git a/onert-micro/onert-micro/src/execute/kernels/LeakyRelu.cpp b/onert-micro/onert-micro/src/execute/kernels/LeakyRelu.cpp index 9d483266af2..d2b6467ca52 100644 --- a/onert-micro/onert-micro/src/execute/kernels/LeakyRelu.cpp +++ b/onert-micro/onert-micro/src/execute/kernels/LeakyRelu.cpp @@ -14,91 +14,14 @@ * limitations under the License. */ -#include "OMStatus.h" - -#include "core/OMUtils.h" - -#include "execute/OMKernelExecutionBuilder.h" -#include "execute/OMRuntimeKernel.h" - -#include "PALReluCommon.h" +#include "execute/kernels/ReluCommon.h" using namespace onert_micro; using namespace onert_micro::execute; -namespace -{ - -constexpr uint32_t inputTensorIdx = 0; -constexpr uint32_t outputTensorIdx = 0; - -} // namespace - // NOTE: doesnt currently support dynamic shapes OMStatus onert_micro::execute::execute_kernel_CircleLeakyRelu(const OMExecuteArgs &execute_args) { - core::OMRuntimeContext &runtime_context = execute_args.runtime_context; - core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage; - uint16_t op_index = execute_args.kernel_index; - - const circle::Tensor *input = nullptr; - const circle::Tensor *output = nullptr; - - uint8_t *input_data = nullptr; - uint8_t *output_data = nullptr; - - OMStatus status = Ok; - - OMRuntimeKernel runtime_kernel; - runtime_kernel.readKernel(op_index, runtime_context); - - input = runtime_kernel.inputs[inputTensorIdx]; - output = runtime_kernel.outputs[outputTensorIdx]; - - assert(input != nullptr); - assert(output != nullptr); - - status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context); - if (status != Ok) - return status; - - input_data = runtime_kernel.inputs_data[inputTensorIdx]; - output_data = runtime_kernel.outputs_data[outputTensorIdx]; - - const auto *options = runtime_kernel.first_operator->builtin_options_as_LeakyReluOptions(); - - if (options == nullptr) - return UnknownError; - - assert(input_data != nullptr); - assert(output_data != nullptr); - - switch (input->type()) - { -#ifndef DIS_FLOAT - case circle::TensorType_FLOAT32: - { - - core::OMRuntimeShape input_shape(input); - core::OMRuntimeShape output_shape(output); - - const float *input_data_float = core::utils::castInputData(input_data); - float *output_data_float = core::utils::castOutputData(output_data); - - assert(output_data_float); - const int flat_size = input_shape.flatSize(); - - status = - pal::ReLUCommon(flat_size, input_data_float, output_data_float, options->alpha(), false); - } - break; -#endif // DIS_FLOAT - default: - { - status = UnsupportedType; - assert(false && "Unsupported type."); - } - } - - return status; + bool is_relu_6 = false; + return execute_relu_common(execute_args, is_relu_6); } diff --git a/onert-micro/onert-micro/src/execute/kernels/ReluCommon.cpp b/onert-micro/onert-micro/src/execute/kernels/ReluCommon.cpp index e25b0fa0b7d..ae464bffef9 100644 --- a/onert-micro/onert-micro/src/execute/kernels/ReluCommon.cpp +++ b/onert-micro/onert-micro/src/execute/kernels/ReluCommon.cpp @@ -63,6 +63,11 @@ OMStatus onert_micro::execute::execute_relu_common(const OMExecuteArgs &execute_ assert(input_data != nullptr); assert(output_data != nullptr); + float alpha = 0.f; + auto options = runtime_kernel.first_operator->builtin_options_as_LeakyReluOptions(); + if (options != nullptr) + alpha = options->alpha(); + switch (input->type()) { #ifndef DIS_FLOAT @@ -77,7 +82,7 @@ OMStatus onert_micro::execute::execute_relu_common(const OMExecuteArgs &execute_ assert(output_data_float); const int flat_size = input_shape.flatSize(); - status = pal::ReLUCommon(flat_size, input_data_float, output_data_float, 0.0f, is_relu_6); + status = pal::ReLUCommon(flat_size, input_data_float, output_data_float, alpha, is_relu_6); } break; #endif // DIS_FLOAT @@ -85,6 +90,7 @@ OMStatus onert_micro::execute::execute_relu_common(const OMExecuteArgs &execute_ { status = UnsupportedType; assert(false && "Unsupported type."); + break; } }