diff --git a/onert-micro/luci-interpreter/include/luci_interpreter/Interpreter.h b/onert-micro/luci-interpreter/include/luci_interpreter/Interpreter.h index 337d52831fc..d2afd6ea235 100644 --- a/onert-micro/luci-interpreter/include/luci_interpreter/Interpreter.h +++ b/onert-micro/luci-interpreter/include/luci_interpreter/Interpreter.h @@ -32,6 +32,14 @@ namespace luci_interpreter { +#ifdef ENABLE_TRAINING +namespace training +{ +class TrainingOnertMicro; +} // namespace training + +#endif // ENABLE_TRAINING + class Interpreter { public: @@ -56,6 +64,10 @@ class Interpreter void interpret(); +#ifdef ENABLE_TRAINING + friend class training::TrainingOnertMicro; +#endif // ENABLE_TRAINING + private: // _default_memory_manager should be before _runtime_module due to // the order of deletion in the destructor diff --git a/onert-micro/luci-interpreter/include/luci_interpreter/TrainingOnertMicro.h b/onert-micro/luci-interpreter/include/luci_interpreter/TrainingOnertMicro.h new file mode 100644 index 00000000000..f1ca77675a3 --- /dev/null +++ b/onert-micro/luci-interpreter/include/luci_interpreter/TrainingOnertMicro.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_TRAINING + +#ifndef LUCI_INTERPRETER_TRAINING_ONERT_MICRO_H +#define LUCI_INTERPRETER_TRAINING_ONERT_MICRO_H + +#include "luci_interpreter/TrainingSettings.h" +#include "luci_interpreter/Interpreter.h" +#include "core/TrainingModule.h" + +namespace luci_interpreter +{ +namespace training +{ + +class TrainingOnertMicro +{ +public: + explicit TrainingOnertMicro(Interpreter *interpreter, TrainingSettings &settings); + + ~TrainingOnertMicro(); + + Status enableTrainingMode(); + + Status disableTrainingMode(bool resetWeights = false); + + Status train(uint32_t number_of_train_samples, const uint8_t *train_data, + const uint8_t *label_train_data); + + Status test(uint32_t number_of_train_samples, const uint8_t *test_data, + const uint8_t *label_test_data, void *metric_value_result); + +private: + Interpreter *_interpreter; + + TrainingSettings &_settings; + + TrainingModule _module; + + bool _is_training_mode; +}; + +} // namespace training +} // namespace luci_interpreter + +#endif // LUCI_INTERPRETER_TRAINING_ONERT_MICRO_H + +#endif // ENABLE_TRAINING diff --git a/onert-micro/luci-interpreter/pal/mcu/KernelsToTrain.lst b/onert-micro/luci-interpreter/pal/mcu/KernelsToTrain.lst index e69de29bb2d..1025141ce2c 100644 --- a/onert-micro/luci-interpreter/pal/mcu/KernelsToTrain.lst +++ b/onert-micro/luci-interpreter/pal/mcu/KernelsToTrain.lst @@ -0,0 +1 @@ +REGISTER_TRAIN_KERNEL(FULLY_CONNECTED, FullyConnected) diff --git a/onert-micro/luci-interpreter/src/CMakeLists.txt b/onert-micro/luci-interpreter/src/CMakeLists.txt index 75d6836df2f..d890d3becd3 100644 --- a/onert-micro/luci-interpreter/src/CMakeLists.txt +++ b/onert-micro/luci-interpreter/src/CMakeLists.txt @@ -31,7 +31,7 @@ target_include_directories(${LUCI_INTERPRETER_CORE} PUBLIC ${LUCI_INTERPRETER_KE message(STATUS "LUCI INTERPTER INITALIZED") set(SOURCES - "${LUCI_INTERPRETER_INCLUDE_DIR}/luci_interpreter/Interpreter.h" Interpreter.cpp) + "${LUCI_INTERPRETER_INCLUDE_DIR}/luci_interpreter/Interpreter.h" Interpreter.cpp TrainingOnertMicro.cpp) add_library(${LUCI_INTERPRETER_BINARY} STATIC ${SOURCES}) diff --git a/onert-micro/luci-interpreter/src/TrainingOnertMicro.cpp b/onert-micro/luci-interpreter/src/TrainingOnertMicro.cpp new file mode 100644 index 00000000000..a485697a213 --- /dev/null +++ b/onert-micro/luci-interpreter/src/TrainingOnertMicro.cpp @@ -0,0 +1,242 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_TRAINING + +#include "luci_interpreter/TrainingOnertMicro.h" +#include "core/TrainingModule.h" + +#include + +namespace luci_interpreter +{ + +namespace training +{ + +namespace +{ + +float calculateMSEError(const float *predicted_values, const float *target_values, + const uint32_t output_size) +{ + const uint32_t output_number_values = output_size / sizeof(float); + float result = 0.0f; + for (int i = 0; i < output_number_values; ++i) + { + result += std::pow(predicted_values[i] - target_values[i], 2); + } + + return result / output_number_values; +} + +float calculateMAEError(const float *predicted_values, const float *target_values, + const uint32_t output_size) +{ + const uint32_t output_number_values = output_size / sizeof(float); + float result = 0.0f; + for (int i = 0; i < output_number_values; ++i) + { + result += std::abs(predicted_values[i] - target_values[i]); + } + + return result / output_number_values; +} + +Status calculateError(const uint8_t *predicted_value, const uint8_t *target_value, void *result, + const uint32_t output_size, MetricsTypeEnum error_type) +{ + switch (error_type) + { + case MSE: + { + float *result_float = reinterpret_cast(result); + *result_float += + calculateMSEError(reinterpret_cast(predicted_value), + reinterpret_cast(target_value), output_size); + break; + } + case MAE: + { + float *result_float = reinterpret_cast(result); + *result_float += + calculateMAEError(reinterpret_cast(predicted_value), + reinterpret_cast(target_value), output_size); + break; + } + default: + { + return Error; + } + } + + return Ok; +} + +} // namespace + +TrainingOnertMicro::TrainingOnertMicro(Interpreter *interpreter, TrainingSettings &settings) + : _interpreter(interpreter), _settings(settings), _is_training_mode(false), + _module(&interpreter->_runtime_module) +{ + // Do nothing +} + +Status TrainingOnertMicro::enableTrainingMode() +{ + if (_is_training_mode) + { + return DoubleTrainModeError; + } + + const Status status = _module.enableTrainingMode(_settings, &_interpreter->_memory_manager); + + if (status != Ok) + assert("Some error during enabling training mode"); + + _is_training_mode = true; + + return status; +} + +Status TrainingOnertMicro::disableTrainingMode(bool resetWeights) +{ + if (_is_training_mode == false) + { + return Ok; + } + + const Status status = _module.disableTrainingMode(resetWeights); + + if (status != Ok) + assert("Some error during disabling training mode"); + + _is_training_mode = false; + + return status; +} + +Status TrainingOnertMicro::train(uint32_t number_of_train_samples, const uint8_t *train_data, + const uint8_t *label_train_data) +{ + if (_is_training_mode == false) + return EnableTrainModeError; + + const uint32_t batch_size = _settings.batch_size; + + const uint32_t num_inferences = number_of_train_samples / batch_size; + + const uint32_t remains = number_of_train_samples % batch_size; + + const uint32_t epochs = _settings.number_of_epochs; + + const int32_t input_tensor_size = _interpreter->getInputDataSizeByIndex(0); + const int32_t output_tensor_size = _interpreter->getOutputDataSizeByIndex(0); + + const uint8_t *cur_train_data = train_data; + const uint8_t *cur_label_train_data = label_train_data; + + for (uint32_t epoch = 0; epoch < epochs; ++epoch) + { + for (uint32_t infer = 0; infer < num_inferences; ++infer) + { + for (uint32_t batch = 0; batch < batch_size; ++batch) + { + _interpreter->allocateAndWriteInputTensor(0, cur_train_data, input_tensor_size); + + _interpreter->interpret(); + + _module.computeGradients(_settings, cur_label_train_data); + cur_train_data += input_tensor_size; + cur_label_train_data += output_tensor_size; + } + + _module.updateWeights(_settings); + } + cur_train_data = train_data; + cur_label_train_data = label_train_data; + } + + return Ok; +} + +Status TrainingOnertMicro::test(uint32_t number_of_train_samples, const uint8_t *test_data, + const uint8_t *label_test_data, void *metric_value_result) +{ + const int32_t input_tensor_size = _interpreter->getInputDataSizeByIndex(0); + const int32_t output_tensor_size = _interpreter->getOutputDataSizeByIndex(0); + + const uint8_t *cur_test_data = test_data; + const uint8_t *cur_label_test_data = label_test_data; + + switch (_settings.metric) + { + case MSE: + case MAE: + { + float *result_float = reinterpret_cast(metric_value_result); + *result_float = 0.0f; + break; + } + default: + { + return Error; + } + } + + for (uint32_t sample = 0; sample < number_of_train_samples; ++sample) + { + _interpreter->allocateAndWriteInputTensor(0, cur_test_data, input_tensor_size); + + _interpreter->interpret(); + + const uint8_t *output_data = _interpreter->readOutputTensor(0); + + Status status = calculateError(output_data, cur_label_test_data, metric_value_result, + output_tensor_size, _settings.metric); + + if (status != Ok) + return status; + + cur_test_data += input_tensor_size; + cur_label_test_data += output_tensor_size; + } + + switch (_settings.metric) + { + case MSE: + case MAE: + { + float *result_float = reinterpret_cast(metric_value_result); + *result_float /= number_of_train_samples; + break; + } + default: + { + return Error; + } + } + + return Ok; +} + +TrainingOnertMicro::~TrainingOnertMicro() { disableTrainingMode(); } + +} // namespace training + +} // namespace luci_interpreter + +#endif // ENABLE_TRAINING diff --git a/onert-micro/luci-interpreter/src/kernels/Builders.h b/onert-micro/luci-interpreter/src/kernels/Builders.h index addb0efa921..98dca06cf10 100644 --- a/onert-micro/luci-interpreter/src/kernels/Builders.h +++ b/onert-micro/luci-interpreter/src/kernels/Builders.h @@ -46,6 +46,7 @@ using BaseRuntimeGraph = RuntimeGraph; #undef REGISTER_KERNEL +#ifdef ENABLE_TRAINING namespace training { #define REGISTER_TRAIN_KERNEL(builtin_operator, name) \ @@ -58,6 +59,8 @@ namespace training #undef REGISTER_TRAIN_KERNEL } // namespace training +#endif // ENABLE_TRAINING + } // namespace luci_interpreter #endif // LUCI_INTERPRETER_KERNELS_NODES_BUILDERS_H