Skip to content

Commit

Permalink
[onert-micro] Add GRU backward execution
Browse files Browse the repository at this point in the history
This pr adds GRU backward execution.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]
  • Loading branch information
Artem Balyshev committed Aug 26, 2024
1 parent 4540b0c commit aa9e945
Show file tree
Hide file tree
Showing 10 changed files with 469 additions and 18 deletions.
33 changes: 17 additions & 16 deletions onert-micro/eval-driver/TrainingDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,6 @@ int entry(int argc, char **argv)
const uint32_t training_epochs = 30;
const float lambda = 0.001f;
const uint32_t BATCH_SIZE = 32;
const uint32_t INPUT_SIZE = 180;
const uint32_t OUTPUT_SIZE = 4;
const uint32_t num_train_layers = 10;
const onert_micro::OMLoss loss = onert_micro::CROSS_ENTROPY;
const onert_micro::OMTrainOptimizer train_optim = onert_micro::ADAM;
Expand Down Expand Up @@ -211,6 +209,9 @@ int entry(int argc, char **argv)
onert_micro::OMTrainingInterpreter train_interpreter;
train_interpreter.importTrainModel(circle_model.data(), config);

const uint32_t OUTPUT_SIZE = train_interpreter.getOutputSizeAt(0);
const uint32_t INPUT_SIZE = train_interpreter.getInputSizeAt(0);

// Temporary buffer to read input data from file using BATCH_SIZE
float training_input[BATCH_SIZE * INPUT_SIZE];
float training_target[BATCH_SIZE * OUTPUT_SIZE];
Expand Down Expand Up @@ -263,11 +264,11 @@ int entry(int argc, char **argv)
if (CLASSIFICATION_TASK)
{
// Evaluate cross_entropy and accuracy metrics
train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS,
train_interpreter.evaluateMetric(config, onert_micro::CROSS_ENTROPY_METRICS,
reinterpret_cast<void *>(&cross_entropy_metric),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast<void *>(&accuracy),
cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::ACCURACY,
reinterpret_cast<void *>(&accuracy), cur_batch_size);

// Save them into vectors
accuracy_v.push_back(accuracy);
Expand All @@ -276,10 +277,10 @@ int entry(int argc, char **argv)
else
{
// Evaluate mse and mae metrics
train_interpreter.evaluateMetric(onert_micro::MSE_METRICS, reinterpret_cast<void *>(&mse),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::MAE_METRICS, reinterpret_cast<void *>(&mae),
cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::MSE_METRICS,
reinterpret_cast<void *>(&mse), cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::MAE_METRICS,
reinterpret_cast<void *>(&mae), cur_batch_size);

// Save them into vectors
mse_v.push_back(mse);
Expand Down Expand Up @@ -335,11 +336,11 @@ int entry(int argc, char **argv)
if (CLASSIFICATION_TASK)
{
// Evaluate cross_entropy and accuracy metrics
train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS,
train_interpreter.evaluateMetric(config, onert_micro::CROSS_ENTROPY_METRICS,
reinterpret_cast<void *>(&cross_entropy_metric),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast<void *>(&accuracy),
cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::ACCURACY,
reinterpret_cast<void *>(&accuracy), cur_batch_size);

// Save them into vectors
accuracy_v.push_back(accuracy);
Expand All @@ -348,10 +349,10 @@ int entry(int argc, char **argv)
else
{
// Evaluate mse and mae metrics
train_interpreter.evaluateMetric(onert_micro::MSE_METRICS, reinterpret_cast<void *>(&mse),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::MAE_METRICS, reinterpret_cast<void *>(&mae),
cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::MSE_METRICS,
reinterpret_cast<void *>(&mse), cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::MAE_METRICS,
reinterpret_cast<void *>(&mae), cur_batch_size);

// Save them into vectors
mse_v.push_back(mse);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace train
namespace pal
{

// Note: dloss_dweight_data should be initialized
void inline FullyConnectedWeightGrad(
const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape,
const float *input_data, const core::OMRuntimeShape &input_shape, float *dloss_dweight_data,
Expand All @@ -48,7 +49,7 @@ void inline FullyConnectedWeightGrad(
float cur_dloss_doutput = dloss_doutput_data[o + depth_bounds.first];
for (uint32_t i = 0; i < accum_depth; ++i)
{
dloss_dweight_data[i + o * accum_depth] = cur_dloss_doutput * input_data[i];
dloss_dweight_data[i + o * accum_depth] += cur_dloss_doutput * input_data[i];
}
}

Expand Down
187 changes: 187 additions & 0 deletions onert-micro/onert-micro/include/pal/common/PALGRUWeightGrad.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_EXECUTE_PAL_GRU_WEIGHT_GRAD_COMMON_H
#define ONERT_MICRO_EXECUTE_PAL_GRU_WEIGHT_GRAD_COMMON_H

#include "OMStatus.h"
#include "core/OMRuntimeShape.h"
#include "core/OMKernelType.h"

#include "PALUtils.h"
#include "ProcessBroadcastShapes.h"
#include "PALFullyConnectedWeightGrad.h"

namespace onert_micro
{
namespace train
{
namespace pal
{
namespace
{

void calculateGRUWeightGrads(
const float *output_grad_data, const float *weight_input_data, float *weight_input_grad_data,
const float *weight_hidden_data, float *weight_hidden_grad_data, const float *bias_input_data,
float *bias_input_grad_data, const float *bias_hidden_data, float *bias_hidden_grad_data,
const float *input_data, float *input_grad_data, float *state_grad_data,
const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_fc_shape,
const core::OMRuntimeShape &output_shape, const core::OMRuntimeShape &weight_input_shape,
const core::OMRuntimeShape &weight_hidden_shape, float *output_data, float *left_logistic_data,
float *left_mul_data, float *right_logistic_data, const float *right_mul_left_input_data,
const float *right_mul_right_input_data, float *tanh_data, const float *middle_mul_left_input,
const float *middle_mul_right_input, float *left_fc_output_grad_buffer,
float *right_fc_output_grad_buffer)
{
int num_elements = output_shape.flatSize();
for (int i = 0; i < num_elements; ++i)
{
// Middle Mul left input grad
float left_middle_mul = output_grad_data[i];
left_middle_mul *= middle_mul_right_input[i];

// Middle Mul right input grad
float right_middle_mul = output_grad_data[i];
right_middle_mul *= middle_mul_left_input[i];

// Tanh` = 1 / (cos(x) ^ 2)
float tanh_grad_value;
{
float tanh = std::tanh(tanh_data[i]);
tanh_grad_value = (1 - tanh * tanh) * right_middle_mul;
}

// Left mul
float left_mul_grad_value = output_grad_data[i] * output_data[i];

// Sub` = -1
// Left Logistic: Logistic` = (exp(-x) * (1 / (1 + exp(-x))) ^ 2)
float left_logistic_grad_value;
{
float log_value = (1 / (1 + std::exp(-left_logistic_data[i])));
left_logistic_grad_value =
log_value * (1 - log_value) * (left_middle_mul + left_mul_grad_value);
}

// Right mul left input
float right_mul_left_input = tanh_grad_value;
right_mul_left_input *= right_mul_right_input_data[i];

// Right mul right input
float right_mul_right_input = tanh_grad_value;
right_mul_right_input *= right_mul_left_input_data[i];

// Right logistic
float right_logistic_grad_value;
{
float log_value = (1 / (1 + std::exp(-right_logistic_data[i])));
right_logistic_grad_value = log_value * (1 - log_value) * right_mul_left_input;
}

// Left concatenation
left_fc_output_grad_buffer[i] = left_logistic_grad_value;
left_fc_output_grad_buffer[i + num_elements] = right_logistic_grad_value;
left_fc_output_grad_buffer[i + 2 * num_elements] = right_mul_right_input;

// Right concatenation
right_fc_output_grad_buffer[i] = left_logistic_grad_value;
right_fc_output_grad_buffer[i + num_elements] = right_logistic_grad_value;
right_fc_output_grad_buffer[i + 2 * num_elements] = tanh_grad_value;
}

// Left fc weight grad
FullyConnectedWeightGrad(left_fc_output_grad_buffer, output_fc_shape, output_data, output_shape,
weight_input_grad_data, weight_input_shape,
core::OpTrainableRankType::ALL);
// Right fc weight grad
FullyConnectedWeightGrad(right_fc_output_grad_buffer, output_fc_shape, input_data, input_shape,
weight_hidden_grad_data, weight_hidden_shape,
core::OpTrainableRankType::ALL);

// Set state grad to zero
std::memset(state_grad_data, 0, output_shape.flatSize() * sizeof(float));
}

} // namespace

OMStatus GRUWeightGrads(
const float *output_grad_data, const float *weight_input_data, float *weight_input_grad_data,
const float *weight_hidden_data, float *weight_hidden_grad_data, const float *bias_input_data,
float *bias_input_grad_data, const float *bias_hidden_data, float *bias_hidden_grad_data,
const float *input_data, float *input_grad_data, float *state_grad_data,
const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape,
const core::OMRuntimeShape &weight_input_shape, const core::OMRuntimeShape &weight_hidden_shape,
const core::OMRuntimeShape &output_shape_fc, float *intermediate_buffer,
float *left_fc_output_grad_buffer, float *right_fc_output_grad_buffer)
{
const int32_t time = input_shape.dims(0);

// Init pointers to intermediate values
size_t offset = output_shape.flatSize();

size_t data_type_size = sizeof(float);
const int32_t num_of_intermediate_tensors = 9;
size_t time_offset = num_of_intermediate_tensors * offset;

core::OMRuntimeShape two_dim_input_shape(2);
auto dim_count = input_shape.dimensionsCount();
if (dim_count < 2)
return UnsupportedType;
two_dim_input_shape.setDim(0, input_shape.dims(dim_count - 2));
two_dim_input_shape.setDim(1, input_shape.dims(dim_count - 1));

core::OMRuntimeShape two_dim_output_shape(2);
dim_count = output_shape.dimensionsCount();
if (dim_count < 2)
return UnsupportedType;
two_dim_output_shape.setDim(0, output_shape.dims(dim_count - 2));
two_dim_output_shape.setDim(1, output_shape.dims(dim_count - 1));

std::memset(weight_input_grad_data, 0, output_shape.flatSize() * sizeof(float) * time);
std::memset(weight_hidden_grad_data, 0, input_shape.dims(2) * sizeof(float) * time);

for (int i = 0; i < time; ++i)
{
float *output_data = intermediate_buffer;
float *left_logistic_data = output_data + offset;
float *left_mul_data = left_logistic_data + offset;
float *right_logistic_data = left_mul_data + offset;
float *right_mul_left_input_data = right_logistic_data + offset;
float *right_mul_right_input_data = right_mul_left_input_data + offset;
float *tanh_data = right_mul_right_input_data + offset;
float *middle_mul_left_input = tanh_data + offset;
float *middle_mul_right_input = middle_mul_left_input + offset;

calculateGRUWeightGrads(
output_grad_data, weight_input_data, weight_input_grad_data, weight_hidden_data,
weight_hidden_grad_data, bias_input_data, bias_input_grad_data, bias_hidden_data,
bias_hidden_grad_data, input_data, input_grad_data, state_grad_data, two_dim_input_shape,
output_shape_fc, two_dim_output_shape, weight_input_shape, weight_hidden_shape, output_data,
left_logistic_data, left_mul_data, right_logistic_data, right_mul_left_input_data,
right_mul_right_input_data, tanh_data, middle_mul_left_input, middle_mul_right_input,
left_fc_output_grad_buffer, right_fc_output_grad_buffer);
input_data += input_shape.dims(2);
intermediate_buffer += time_offset;
}
return Ok;
}

} // namespace pal
} // namespace train
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_PAL_GRU_WEIGHT_GRAD_COMMON_H
2 changes: 2 additions & 0 deletions onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ REGISTER_TRAIN_KERNEL(SOFTMAX, Softmax)
REGISTER_TRAIN_KERNEL(RESHAPE, Reshape)
REGISTER_TRAIN_KERNEL(CONV_2D, Conv2D)
REGISTER_TRAIN_KERNEL(MAX_POOL_2D, MaxPool2D)
REGISTER_TRAIN_KERNEL(GRU, GRU)
REGISTER_TRAIN_KERNEL(STRIDED_SLICE, StridedSlice)
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ bool isTrainableWeights(const circle::OperatorCode *opcode)
{
case circle::BuiltinOperator_FULLY_CONNECTED:
case circle::BuiltinOperator_CONV_2D:
case circle::BuiltinOperator_GRU:
return true;
default:
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ bool isTrainableWeights(const circle::OperatorCode *opcode)
{
case circle::BuiltinOperator_FULLY_CONNECTED:
case circle::BuiltinOperator_CONV_2D:
case circle::BuiltinOperator_GRU:
return true;
default:
return false;
Expand Down
7 changes: 6 additions & 1 deletion onert-micro/onert-micro/src/train/OMBackpropExecute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ OMStatus OMBackpropExecute::runBackward(const OMConfig &config, OMBackpropExecut
args.is_last_layer = false;
}

if (trainable_ops_config.find(cur_op_index) != trainable_ops_config.end())
if (trainable_ops_config.empty())
{
args.is_trainable_layer = true;
args.train_rank_type = core::OpTrainableRankType::ALL;
}
else if (trainable_ops_config.find(cur_op_index) != trainable_ops_config.end())
{
args.is_trainable_layer = true;
args.train_rank_type = core::OpTrainableRankType(trainable_ops_config[cur_op_index]);
Expand Down
3 changes: 3 additions & 0 deletions onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ OMStatus onert_micro::train::train_kernel_CircleFullyConnected(const OMBackpropE
weight_shape = dynamic_shapes;

// 2. Calculate weight gradient
// Init weight grads with zeros
std::memset(dloss_dweight_data, 0,
output_shape.dims(1) * input_shape.dims(1) * sizeof(float));
pal::FullyConnectedWeightGrad(
core::utils::castInputData<float>(dloss_doutput_data), output_shape,
core::utils::castInputData<float>(input_data), input_shape,
Expand Down
Loading

0 comments on commit aa9e945

Please sign in to comment.