[onert-micro] Add GRU forward execution
This PR adds GRU forward execution.
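
Besides the kernel itself, the change threads the OMConfig through the runtime API: OMInterpreter::run(), OMTrainingInterpreter::run(), OMRuntimeModule::run(), and evaluateMetric() now take the config that was used at import time. A minimal usage sketch of the updated inference path, assuming model_ptr points at a valid model buffer and that the classes live in the onert_micro namespace (error handling omitted):

onert_micro::OMInterpreter interpreter;
onert_micro::OMConfig config;

interpreter.importModel(model_ptr, config); // import with a config...
interpreter.run(config);                    // ...and pass the same config to run()

// output accessor name as used in OMTestUtils.h below
float *output = reinterpret_cast<float *>(interpreter.getOutputDataAt(0));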

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]
Artem Balyshev committed Aug 22, 2024
1 parent 987fafa commit 3fa104f
Showing 20 changed files with 852 additions and 18 deletions.
2 changes: 1 addition & 1 deletion onert-micro/eval-driver/Driver.cpp
@@ -114,7 +114,7 @@ int entry(int argc, char **argv)
}

// Do inference.
- interpreter.run();
+ interpreter.run(config);
}

// Get output.
2 changes: 1 addition & 1 deletion onert-micro/onert-micro/include/OMInterpreter.h
@@ -40,7 +40,7 @@ class OMInterpreter

OMStatus importModel(const char *model_ptr, const OMConfig &config);

- OMStatus run();
+ OMStatus run(const OMConfig &config);

OMStatus reset();

7 changes: 4 additions & 3 deletions onert-micro/onert-micro/include/OMTrainingInterpreter.h
@@ -68,9 +68,10 @@ class OMTrainingInterpreter
// Note: calculation will be done on test_size number of test samples
// Warning: before using evaluateMetric call: 1) importTrainModel; 2) setInput; 3) setTarget
// Note: the number of samples in data should be equal to test_size
- OMStatus evaluateMetric(OMMetrics metric, void *metric_val, uint32_t test_size)
+ OMStatus evaluateMetric(const OMConfig &config, OMMetrics metric, void *metric_val,
+                         uint32_t test_size)
{
- return _training_runtime_module.evaluateMetric(metric, metric_val, test_size);
+ return _training_runtime_module.evaluateMetric(config, metric, metric_val, test_size);
}

// To get input and output flat size
@@ -86,7 +87,7 @@
// Load current status from checkpoint and save it in current model and in current config
OMStatus loadCheckpoint(OMConfig &config, const char *load_path);

- OMStatus run() { return _training_runtime_module.run(); }
+ OMStatus run(const OMConfig &config) { return _training_runtime_module.run(config); }
OMStatus allocateInputs() { return _training_runtime_module.allocateInputs(); }

void *getInputData(uint32_t position);
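
For the training path, the evaluateMetric protocol described in the comments above becomes the sketch below. The importTrainModel signature is assumed to mirror importModel, and the setInput/setTarget calls and the OMMetrics value are placeholders whose exact forms are not shown in this diff:

onert_micro::OMTrainingInterpreter train_interpreter;
onert_micro::OMConfig config;

train_interpreter.importTrainModel(model_ptr, config); // 1) assumed signature
// 2) train_interpreter.setInput(...)  -- test inputs, test_size samples
// 3) train_interpreter.setTarget(...) -- matching targets
float metric_val = 0.0f; // pre-initialize: per-sample results are accumulated into it
train_interpreter.evaluateMetric(config, /*metric=*/some_metric, &metric_val, test_size);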
2 changes: 1 addition & 1 deletion onert-micro/onert-micro/include/core/OMRuntimeModule.h
@@ -43,7 +43,7 @@ class OMRuntimeModule
~OMRuntimeModule() = default;

OMStatus importModel(const char *model_ptr, const OMConfig &config);
- OMStatus run();
+ OMStatus run(const OMConfig &config);
OMStatus reset();

uint32_t getNumberOfInputs();
onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h
@@ -68,7 +68,8 @@ class OMTrainingRuntimeModule : public OMRuntimeModule
// 2) metric_val should be initialized with some value before calling this method, because
// after the calculation for the current batch_num (the sequence number of the current
// sample) the result is added to metric_val
- OMStatus evaluateMetric(OMMetrics metric, void *metric_val, uint32_t test_size);
+ OMStatus evaluateMetric(const OMConfig &config, OMMetrics metric, void *metric_val,
+                         uint32_t test_size);

// Set input data for input with input_index
// Note: the number of samples in data should be equal to batch_size in the config structure
2 changes: 2 additions & 0 deletions onert-micro/onert-micro/include/execute/OMExecuteArgs.h
@@ -33,6 +33,8 @@ struct OMExecuteArgs
core::OMRuntimeContext &runtime_context;
uint16_t kernel_index;
core::OMRuntimeModule &runtime_module;
+ uint32_t num_train_layers = 0;
+ bool is_train_mode = false;
};

} // namespace execute
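These two new fields appear to be how execute-time code learns that it is running under training; the GRU implementation below saves intermediate activations for gradient computation only in train mode.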
2 changes: 1 addition & 1 deletion onert-micro/onert-micro/include/execute/OMRuntimeKernel.h
@@ -23,7 +23,7 @@

#include <cstdint>

- constexpr static uint32_t maxInputSize = 5;
+ constexpr static uint32_t maxInputSize = 6;
constexpr static uint32_t maxOutputSize = 5;

namespace onert_micro
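The bump from 5 to 6 inputs presumably accommodates the new GRU kernel, whose PAL signature below takes six input tensors: the input sequence, the input and hidden weights, the input and hidden biases, and the initial hidden state.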
2 changes: 1 addition & 1 deletion onert-micro/onert-micro/include/execute/OMTestUtils.h
@@ -61,7 +61,7 @@ std::vector<U> checkKernel(uint32_t num_inputs,
}
}

- interpreter.run();
+ interpreter.run(config);

U *output_data = reinterpret_cast<U *>(interpreter.getOutputDataAt(0));
const size_t num_elements = interpreter.getOutputSizeAt(0);
209 changes: 209 additions & 0 deletions onert-micro/onert-micro/include/pal/common/PALGRUCommon.h
@@ -0,0 +1,209 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H
#define ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H

#include "OMStatus.h"
#include "core/OMRuntimeShape.h"

#include "PALUtils.h"
#include "ProcessBroadcastShapes.h"
#include "PALFullyConnected.h"
#include "PALLogistic.h"

namespace onert_micro
{
namespace execute
{
namespace pal
{
namespace
{
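// A summary of the math implemented by calculateGRU below, inferred from the
// code (standard GRU cell, linear-before-reset variant). With h the previous
// hidden state, x the current input, and each FC result split into three
// gate-sized chunks [z | r | c]:
//   a = W_i * h + b_i   (FC Input; note: despite the name, weight_input_data
//                        is applied to the hidden state)
//   b = W_h * x + b_h   (FC Hidden; applied to the current input)
//   z = sigmoid(a_z + b_z)      -- update gate
//   r = sigmoid(a_r + b_r)      -- reset gate
//   c = tanh(r * a_c + b_c)     -- candidate state
//   h' = z * h + (1 - z) * c    -- new hidden state, written to output_data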
void calculateGRU(const float *input_data, const float *weight_input_data,
const float *weight_hidden_data, const float *bias_input_data,
const float *bias_hidden_data, float *output_data,
const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape,
const core::OMRuntimeShape &weight_input_shape,
const core::OMRuntimeShape &weight_hidden_shape, float *output_input_data,
float *output_hidden_data, const core::OMRuntimeShape &output_shape_fc,
float *intermediate_buffer)
{
core::FullyConnectedParams op_params{};
// Since the FC nodes inside GRU have no fused activations, just use numeric limits
op_params.float_activation_min = std::numeric_limits<float>::lowest();
op_params.float_activation_max = std::numeric_limits<float>::max();
// If intermediate_buffer != nullptr, then we are in train mode and need to save intermediate information
bool is_train_mode = intermediate_buffer != nullptr;
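// When is_train_mode is true, the code below appends nine chunks of
// output_shape.flatSize() floats each to intermediate_buffer, in order:
// previous hidden state, update-gate logistic input, update gate z,
// reset-gate logistic input, reset gate r, the hidden-side candidate chunk,
// the tanh input, (1 - z), and the candidate state.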
if (is_train_mode)
{
// Copy input for FC Input to calculate weights gradients
std::memcpy(intermediate_buffer, output_data, output_shape.flatSize() * sizeof(float));
// Move intermediate_buffer pointer
intermediate_buffer += output_shape.flatSize();
}
// FC Input
FullyConnected(op_params, output_data, weight_input_shape, weight_input_data, bias_input_data,
output_shape_fc, output_input_data);

// FC Hidden
// Note: the input of this FC node is the original input data, so it does not need the intermediate buffer
FullyConnected(op_params, input_data, weight_hidden_shape, weight_hidden_data, bias_hidden_data,
output_shape_fc, output_hidden_data);

int num_elements = output_shape_fc.dims(1) / 3;

float *second_hidden_part = output_hidden_data + num_elements;
float *second_input_part = output_input_data + num_elements;

float *third_hidden_part = second_hidden_part + num_elements;
float *third_input_part = second_input_part + num_elements;

// Calculate Left part
for (int i = 0; i < num_elements; ++i)
{
output_input_data[i] += output_hidden_data[i];
}

// If train mode - save logistic input
if (is_train_mode)
{
std::memcpy(intermediate_buffer, output_input_data, output_shape.flatSize() * sizeof(float));
// Move intermediate_buffer pointer
intermediate_buffer += output_shape.flatSize();
}
Logistic(num_elements, output_input_data, output_input_data);

// If train mode - save input of the left-most mul (its right operand)
if (is_train_mode)
{
std::memcpy(intermediate_buffer, output_input_data, output_shape.flatSize() * sizeof(float));
// Move intermediate_buffer pointer
intermediate_buffer += output_shape.flatSize();
}
// Calculate the left-most mul
float *most_left_part_final = output_input_data;
float *first_part = output_input_data;
for (int i = 0; i < num_elements; ++i)
{
output_data[i] *= most_left_part_final[i];
first_part[i] = 1.0f - first_part[i];
}

// Calc second part
for (int i = 0; i < num_elements; ++i)
{
second_hidden_part[i] += second_input_part[i];
}
// If train mode - save logistic input
if (is_train_mode)
{
std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float));
// Move intermediate_buffer pointer
intermediate_buffer += output_shape.flatSize();
}
Logistic(num_elements, second_hidden_part, second_hidden_part);

// If train mode - save mul input (left and right)
if (is_train_mode)
{
// Left input
std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float));
// Move intermediate_buffer pointer
intermediate_buffer += output_shape.flatSize();

// Right input
std::memcpy(intermediate_buffer, third_input_part, output_shape.flatSize() * sizeof(float));
// Move intermediate_buffer pointer
intermediate_buffer += output_shape.flatSize();
}
for (int i = 0; i < num_elements; ++i)
{
second_hidden_part[i] *= third_input_part[i];
second_hidden_part[i] += third_hidden_part[i];
}
// If train mode - save tanh input
if (is_train_mode)
{
std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float));
// Move intermediate_buffer pointer
intermediate_buffer += output_shape.flatSize();
}
for (int i = 0; i < num_elements; ++i)
{
second_hidden_part[i] = std::tanh(second_hidden_part[i]);
}

// If train mode - save mul input (left and right)
if (is_train_mode)
{
// Left input
std::memcpy(intermediate_buffer, first_part, output_shape.flatSize() * sizeof(float));
// Move intermediate_buffer pointer
intermediate_buffer += output_shape.flatSize();

// Right input
std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float));
// Move intermediate_buffer pointer
intermediate_buffer += output_shape.flatSize();
}
for (int i = 0; i < num_elements; ++i)
{
second_hidden_part[i] *= first_part[i];
output_data[i] += second_hidden_part[i];
}
}

} // namespace

OMStatus GRU(const float *input_data, const float *weight_input_data,
const float *weight_hidden_data, const float *bias_input_data,
const float *bias_hidden_data, const float *hidden_state_data, float *output_data,
float *output_input_data, float *output_hidden_data,
const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape,
const core::OMRuntimeShape &weight_input_shape,
const core::OMRuntimeShape &weight_hidden_shape, const size_t intermediate_buffer_size,
float *intermediate_buffer)
{
const int32_t time = input_shape.dims(0);
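// The sequence is processed one step at a time: each iteration advances
// input_data by input_shape.dims(2) floats and updates the hidden state held
// in output_data in place, so this path appears to assume batch size 1.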

core::OMRuntimeShape output_shape_fc(2);
output_shape_fc.setDim(0, 1);
output_shape_fc.setDim(1, weight_hidden_shape.dims(0));

std::memcpy(output_data, hidden_state_data, output_shape.flatSize() * sizeof(float));

for (int i = 0; i < time; ++i)
{
calculateGRU(input_data, weight_input_data, weight_hidden_data, bias_input_data,
bias_hidden_data, output_data, input_shape, output_shape, weight_input_shape,
weight_hidden_shape, output_input_data, output_hidden_data, output_shape_fc,
intermediate_buffer);
input_data += input_shape.dims(2);
if (intermediate_buffer_size != 0)
{
assert(intermediate_buffer != nullptr);
intermediate_buffer += intermediate_buffer_size;
}
}
return Ok;
}

} // namespace pal
} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H
1 change: 1 addition & 0 deletions onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst
@@ -23,6 +23,7 @@ REGISTER_KERNEL(GATHER_ND, GatherND)
REGISTER_KERNEL(EXP, Exp)
REGISTER_KERNEL(GREATER, Greater)
REGISTER_KERNEL(GREATER_EQUAL, GreaterEqual)
+ REGISTER_KERNEL(GRU, GRU)
REGISTER_KERNEL(EXPAND_DIMS, ExpandDims)
REGISTER_KERNEL(ELU, Elu)
REGISTER_KERNEL(EQUAL, Equal)
23 changes: 23 additions & 0 deletions onert-micro/onert-micro/include/pal/mcu/PALGRU.h
@@ -0,0 +1,23 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_EXECUTE_PAL_GRU_H
#define ONERT_MICRO_EXECUTE_PAL_GRU_H

#include "PALGRUCommon.h"

#endif // ONERT_MICRO_EXECUTE_PAL_GRU_H