Skip to content

Commit

Permalink
add leaky relu
Browse files Browse the repository at this point in the history
  • Loading branch information
Artem Balyshev committed Jun 26, 2024
1 parent 98f85b6 commit 0140d19
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 81 deletions.
83 changes: 3 additions & 80 deletions onert-micro/onert-micro/src/execute/kernels/LeakyRelu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,91 +14,14 @@
* limitations under the License.
*/

#include "OMStatus.h"

#include "core/OMUtils.h"

#include "execute/OMKernelExecutionBuilder.h"
#include "execute/OMRuntimeKernel.h"

#include "PALReluCommon.h"
#include "execute/kernels/ReluCommon.h"

using namespace onert_micro;
using namespace onert_micro::execute;

namespace
{

constexpr uint32_t inputTensorIdx = 0;
constexpr uint32_t outputTensorIdx = 0;

} // namespace

// NOTE: doesnt currently support dynamic shapes
OMStatus onert_micro::execute::execute_kernel_CircleLeakyRelu(const OMExecuteArgs &execute_args)
{
core::OMRuntimeContext &runtime_context = execute_args.runtime_context;
core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage;
uint16_t op_index = execute_args.kernel_index;

const circle::Tensor *input = nullptr;
const circle::Tensor *output = nullptr;

uint8_t *input_data = nullptr;
uint8_t *output_data = nullptr;

OMStatus status = Ok;

OMRuntimeKernel runtime_kernel;
runtime_kernel.readKernel(op_index, runtime_context);

input = runtime_kernel.inputs[inputTensorIdx];
output = runtime_kernel.outputs[outputTensorIdx];

assert(input != nullptr);
assert(output != nullptr);

status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);
if (status != Ok)
return status;

input_data = runtime_kernel.inputs_data[inputTensorIdx];
output_data = runtime_kernel.outputs_data[outputTensorIdx];

const auto *options = runtime_kernel.first_operator->builtin_options_as_LeakyReluOptions();

if (options == nullptr)
return UnknownError;

assert(input_data != nullptr);
assert(output_data != nullptr);

switch (input->type())
{
#ifndef DIS_FLOAT
case circle::TensorType_FLOAT32:
{

core::OMRuntimeShape input_shape(input);
core::OMRuntimeShape output_shape(output);

const float *input_data_float = core::utils::castInputData<float>(input_data);
float *output_data_float = core::utils::castOutputData<float>(output_data);

assert(output_data_float);
const int flat_size = input_shape.flatSize();

status =
pal::ReLUCommon(flat_size, input_data_float, output_data_float, options->alpha(), false);
}
break;
#endif // DIS_FLOAT
default:
{
status = UnsupportedType;
assert(false && "Unsupported type.");
}
}

return status;
bool is_relu_6 = false;
return execute_relu_common(execute_args, is_relu_6);
}
8 changes: 7 additions & 1 deletion onert-micro/onert-micro/src/execute/kernels/ReluCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ OMStatus onert_micro::execute::execute_relu_common(const OMExecuteArgs &execute_
assert(input_data != nullptr);
assert(output_data != nullptr);

float alpha = 0.f;
auto options = runtime_kernel.first_operator->builtin_options_as_LeakyReluOptions();
if (options != nullptr)
alpha = options->alpha();

switch (input->type())
{
#ifndef DIS_FLOAT
Expand All @@ -77,14 +82,15 @@ OMStatus onert_micro::execute::execute_relu_common(const OMExecuteArgs &execute_
assert(output_data_float);
const int flat_size = input_shape.flatSize();

status = pal::ReLUCommon(flat_size, input_data_float, output_data_float, 0.0f, is_relu_6);
status = pal::ReLUCommon(flat_size, input_data_float, output_data_float, alpha, is_relu_6);
}
break;
#endif // DIS_FLOAT
default:
{
status = UnsupportedType;
assert(false && "Unsupported type.");
break;
}
}

Expand Down

0 comments on commit 0140d19

Please sign in to comment.