From f49786a24a1f15ac66e7ba7966a482cea8942644 Mon Sep 17 00:00:00 2001
From: Chunseok Lee
Date: Fri, 21 Jun 2024 16:50:02 +0900
Subject: [PATCH] [onert-micro] onert-micro-dev

- onert-micro api and its implementation

ONE-DCO-1.0-Signed-off-by: Chunseok Lee
---
 onert-micro/onert-micro/include/onert-micro.h | 359 ++++++++++++++
 onert-micro/onert-micro/src/CMakeLists.txt    |   6 +
 .../onert-micro/src/api/CMakeLists.txt        |  14 +
 .../onert-micro/src/api/onert-micro.cpp       | 436 ++++++++++++++++++
 4 files changed, 815 insertions(+)
 create mode 100644 onert-micro/onert-micro/include/onert-micro.h
 create mode 100644 onert-micro/onert-micro/src/api/CMakeLists.txt
 create mode 100644 onert-micro/onert-micro/src/api/onert-micro.cpp

diff --git a/onert-micro/onert-micro/include/onert-micro.h b/onert-micro/onert-micro/include/onert-micro.h
new file mode 100644
index 00000000000..e340e817561
--- /dev/null
+++ b/onert-micro/onert-micro/include/onert-micro.h
@@ -0,0 +1,359 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _ONERT_MICRO_H_
+#define _ONERT_MICRO_H_
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*
+ * typical training flow in onert-micro
+ *
+ * 1. load model or checkpoint
+ *   1-1. (optional) configure training options
+ * 2. feed training input / output (e.g. label) data, one step's worth at a time
+ * 3. train a step
+ * 4. check loss
+ *   4-0. save checkpoint for recovery / resuming training
+ *   4-1. no more training -> go to 5
+ *   4-2. more training -> go to 2
+ * 5. save current state to inference model
+ * 6. inference with inference model
+
+// sample example (training) -- buffer and constant names are placeholders
+// 0. create context
+nnfw_session *session;
+nnfw_create_session(&session);
+// 1. load model (and checkpoint if continuing training)
+nnfw_load_model_from_file(session, MODEL_PATH);
+nnfw_train_import_checkpoint(session, CKPT_PATH);
+// 1-1. (optional) configure training options via nnfw_train_set_traininfo()
+nnfw_train_prepare(session);
+
+float training_input[BATCH_SIZE*INPUT_SIZE];
+float training_label[BATCH_SIZE*OUTPUT_SIZE];
+
+// main training loop
+for(int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
+  for(int step = 0; step < NUM_BATCHES; step++) {
+    // prepare this step's input/label
+    memcpy(training_input, train_input_data + THIS_BATCH_OFFSET,
+           sizeof(float)*BATCH_SIZE*INPUT_SIZE);
+    memcpy(training_label, train_label_data + THIS_BATCH_OFFSET,
+           sizeof(float)*BATCH_SIZE*OUTPUT_SIZE);
+    // 2. feed training input / expected output
+    nnfw_train_set_input(session, 0, training_input, NULL);
+    nnfw_train_set_expected(session, 0, training_label, NULL);
+    // 3. train a step
+    nnfw_train(session, true);
+  }
+  // 4. check loss
+  float loss;
+  nnfw_train_get_loss(session, 0, &loss);
+  if(loss > TARGET_LOSS) {
+    // 4-0/4-2. loss still above target: save a checkpoint and keep training
+    nnfw_train_export_checkpoint(session, CKPT_PATH);
+  }
+  else {
+    // 4-1/5. target reached: save current state as an inference model
+    nnfw_train_export_circle(session, CIRCLE_PATH);
+  }
+}
+*/
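+
+/*
+// sample example (validation) -- illustrative sketch; buffer names are placeholders.
+// With update_weights == false, nnfw_train() runs only the forward pass and does not
+// update weights; an output buffer must be bound first via nnfw_train_set_output().
+float val_output[BATCH_SIZE*OUTPUT_SIZE];
+float val_loss;
+nnfw_train_set_input(session, 0, val_input, NULL);
+nnfw_train_set_expected(session, 0, val_label, NULL);
+nnfw_train_set_output(session, 0, NNFW_TYPE_TENSOR_FLOAT32, val_output, sizeof(val_output));
+nnfw_train(session, false);
+nnfw_train_get_loss(session, 0, &val_loss);
+*/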
+
+typedef struct nnfw_session nnfw_session;
+
+typedef enum
+{
+  /** A tensor of 32 bit floating point */
+  NNFW_TYPE_TENSOR_FLOAT32 = 0,
+  /** A tensor of 32 bit signed integer */
+  NNFW_TYPE_TENSOR_INT32 = 1,
+} NNFW_TYPE;
+
+/**
+ * @brief Result values returned from a call to an API function
+ */
+typedef enum
+{
+  /** Successful */
+  NNFW_STATUS_NO_ERROR = 0,
+  /**
+   * An error code for general use.
+   * Mostly used when there is no specific value for the situation at hand.
+   */
+  NNFW_STATUS_ERROR = 1,
+  /** An unexpected null argument was given. */
+  NNFW_STATUS_UNEXPECTED_NULL = 2,
+  /** A function was called that is not valid for the current session state. */
+  NNFW_STATUS_INVALID_STATE = 3,
+  /** Out of memory */
+  NNFW_STATUS_OUT_OF_MEMORY = 4,
+  /** The given output buffer is too small */
+  NNFW_STATUS_INSUFFICIENT_OUTPUT_SIZE = 5,
+  /** The called API is deprecated */
+  NNFW_STATUS_DEPRECATED_API = 6,
+} NNFW_STATUS;
+
+/**
+ * @brief Maximum rank expressible with nnfw
+ */
+#define NNFW_MAX_RANK (6)
+
+/**
+ * @brief tensor info describes the type and shape of tensors
+ *
+ * This structure is used to describe input and output tensors.
+ * Application can get input and output tensor type and shape described in model by using
+ * {@link nnfw_input_tensorinfo} and {@link nnfw_output_tensorinfo}
+ *
+ * Maximum rank is 6 (NNFW_MAX_RANK). And tensor's dimension value is filled in 'dims' field from
+ * index 0.
+ * For example, if tensor's rank is 4,
+ * application can get dimension value from dims[0], dims[1], dims[2], and dims[3]
+ */
+typedef struct nnfw_tensorinfo
+{
+  /** The data type */
+  NNFW_TYPE dtype;
+  /** The number of dimensions (rank) */
+  int32_t rank;
+  /**
+   * The dimension of tensor.
+   * Maximum rank is 6 (NNFW_MAX_RANK).
+   */
+  int32_t dims[NNFW_MAX_RANK];
+} nnfw_tensorinfo;
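+
+// Illustrative only: a rank-4 float tensor of shape [2, 28, 28, 1] is described as
+//   nnfw_tensorinfo ti = {NNFW_TYPE_TENSOR_FLOAT32, 4, {2, 28, 28, 1, 0, 0}};
+// (entries of 'dims' beyond 'rank' are unused)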
+
+//////////////////////////////////////////////
+// Essential APIs for training
+//////////////////////////////////////////////
+typedef enum
+{
+  NNFW_TRAIN_LOSS_UNDEFINED = 0,
+  NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR = 1,
+  NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY = 2,
+} NNFW_TRAIN_LOSS;
+
+typedef enum
+{
+  /** Undefined */
+  NNFW_TRAIN_LOSS_REDUCTION_UNDEFINED = 0,
+  /** Scalar sum divided by number of elements in losses */
+  NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE = 1,
+  /** Scalar sum of weighted losses */
+  NNFW_TRAIN_LOSS_REDUCTION_SUM = 2,
+} NNFW_TRAIN_LOSS_REDUCTION;
+
+typedef enum
+{
+  NNFW_TRAIN_OPTIMIZER_UNDEFINED = 0,
+  NNFW_TRAIN_OPTIMIZER_SGD = 1,
+  NNFW_TRAIN_OPTIMIZER_ADAM = 2,
+} NNFW_TRAIN_OPTIMIZER;
+
+typedef struct nnfw_loss_info
+{
+  NNFW_TRAIN_LOSS loss;
+  NNFW_TRAIN_LOSS_REDUCTION reduction_type;
+} nnfw_loss_info;
+
+typedef struct nnfw_adam_option
+{
+  float beta;
+  float beta2;
+  float epsilon;
+} nnfw_adam_option;
+
+/**
+ * @brief Maximum number of trainable operations
+ */
+#define NNFW_TRAINABLE_OPS_MAX_SIZE (256)
+
+/**
+ * @brief Training information to prepare training
+ * @todo  Add more training information
+ *        (e.g. optimizer, loss function, ...)
+ */
+typedef struct nnfw_train_info
+{
+  /** Learning rate */
+  float learning_rate = 0.001f;
+  /** Batch size */
+  uint32_t batch_size = 1;
+  /** Loss info */
+  nnfw_loss_info loss_info{.loss = NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY,
+                           .reduction_type = NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE};
+  /** Optimizer type */
+  NNFW_TRAIN_OPTIMIZER opt = NNFW_TRAIN_OPTIMIZER_ADAM;
+  /** Number of trainable operations */
+  uint32_t num_trainable_ops = 0;
+  /** Adam optimizer options */
+  nnfw_adam_option adam_opt{.beta = 0.9f, .beta2 = 0.999f, .epsilon = 1e-7f};
+} nnfw_train_info;
+
+/**
+ * @brief Set training information
+ * @note  This function should be called after calling {@link nnfw_load_model_from_file}
+ *        and before calling {@link nnfw_train_prepare}
+ *
+ * @param[in] session The session to set the training information on
+ * @param[in] info    The training information
+ *
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_train_set_traininfo(nnfw_session *session, const nnfw_train_info *info);
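+
+/*
+// sample example (training options) -- illustrative sketch; values are placeholders
+nnfw_train_info tinfo;                // fields start from the defaults above
+tinfo.learning_rate = 0.01f;
+tinfo.batch_size = 32;
+tinfo.opt = NNFW_TRAIN_OPTIMIZER_SGD; // Adam-specific fields live in tinfo.adam_opt
+nnfw_train_set_traininfo(session, &tinfo); // after load, before prepare
+*/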
+
+/**
+ * @brief Create a new session instance.
+ *
+ * This only creates a session.
+ * Model is loaded after {@link nnfw_load_model_from_file} is invoked.
+ * And inference is performed after {@link nnfw_run} is invoked.
+ *
+ * {@link nnfw_close_session} should be called once
+ * if session is no longer needed
+ *
+ * @param[out] session The session to be created
+ * @return     NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_create_session(nnfw_session **session);
+
+/**
+ * @brief Close a session instance
+ *
+ * After this function is called, any further access to the closed session is invalid
+ *
+ * @param[in] session The session to be closed
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_close_session(nnfw_session *session);
+
+/**
+ * @brief Load model from nnpackage file or directory
+ *
+ * The length of \p package_file_path must not exceed 1024 bytes including zero at the end.
+ *
+ * @param[in] session           nnfw_session loading the given nnpackage file/dir
+ * @param[in] package_file_path Path to the nnpackage file or unzipped directory to be loaded
+ *
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_load_model_from_file(nnfw_session *session, const char *package_file_path);
+
+/**
+ * @brief Prepare session to be ready for training
+ * @note  The session will be entered into training mode
+ *
+ *        If training info is NOT set in the session, this function returns
+ *        @c NNFW_STATUS_ERROR; set it first using {@link nnfw_train_set_traininfo}.
+ *
+ * @param[in] session The session to be prepared for training
+ *
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_train_prepare(nnfw_session *session);
+
+/**
+ * @brief Train the model
+ * @note  This function should be called after {@link nnfw_train_set_input} and
+ *        {@link nnfw_train_set_expected} for each input and expected output
+ *
+ *        In order to use \p update_weights as false, it should be called after
+ *        {@link nnfw_train_set_output}.
+ *
+ * @param[in] session        The session to be trained
+ * @param[in] update_weights If true, update the weights of the model
+ *                           If false, do not update the weights of the model (for validation)
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_train(nnfw_session *session, bool update_weights);
+
+/**
+ * @brief Export current training model into circle model
+ * @note  This function should be called in training mode
+ *        This function should be called after {@link nnfw_train}
+ *
+ * @param[in] session The session from which to export the inference model
+ * @param[in] path    The path where the inference model is exported
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_train_export_circle(nnfw_session *session, const char *path);
+
+/**
+ * @brief Save the current training state as a checkpoint file
+ *
+ * @param[in] session The session whose state is to be saved
+ * @param[in] path    The path of the checkpoint file to be written
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_train_export_checkpoint(nnfw_session *session, const char *path);
+
+/**
+ * @brief Restore a training state from a checkpoint file
+ * @note  This function should be called after {@link nnfw_load_model_from_file}
+ *
+ * @param[in] session The session into which the state is restored
+ * @param[in] path    The path of the checkpoint file to be read
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_train_import_checkpoint(nnfw_session *session, const char *path);
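+
+/*
+// sample example (checkpointing) -- illustrative sketch; paths are placeholders
+nnfw_train_export_checkpoint(session, "model.ckpt"); // snapshot during training
+// ... later, in a fresh session, after nnfw_load_model_from_file():
+nnfw_train_import_checkpoint(session, "model.ckpt"); // resume where training left off
+*/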
+
+/**
+ * @brief Set training input
+ * @note  This function should be called after {@link nnfw_train_prepare}
+ *
+ * @param[in] session    The session to set the training input on
+ * @param[in] index      The index of the training input
+ * @param[in] input      The input buffer for training
+ * @param[in] input_info The shape and type of the input buffer
+ *                       If it is nullptr, it will not change shape and batch size
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_train_set_input(nnfw_session *session, uint32_t index, void *input,
+                                 const nnfw_tensorinfo *input_info);
+
+/**
+ * @brief Set training expected output
+ * @note  This function should be called after {@link nnfw_train_prepare}
+ *
+ * @param session       The session to set the expected model output on
+ * @param index         The index of the training expected output
+ * @param expected      The expected output buffer for training
+ * @param expected_info The shape and type of the expected buffer
+ *                      If it is nullptr, it will not change shape and batch size
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_train_set_expected(nnfw_session *session, uint32_t index, void *expected,
+                                    const nnfw_tensorinfo *expected_info);
+
+/**
+ * @brief Get loss value for expected output
+ * @note  This function should be called after {@link nnfw_train}
+ *
+ * @param[in]  session The session from which to get the loss value
+ * @param[in]  index   The index of the loss value [0, number of expected outputs)
+ * @param[out] loss    The loss value
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_train_get_loss(nnfw_session *session, uint32_t index, float *loss);
+
+/**
+ * @brief Set training output buffer
+ *
+ * This function must be called after {@link nnfw_train_prepare}. The \p buffer given to this
+ * function can be reused across runs. \p length must be greater than or equal to the size the
+ * output operand requires. An output operand can have an unspecified shape that is deduced
+ * dynamically during execution, so \p buffer must be large enough to hold such an output.
+ *
+ * @param[in]  session Session from which output is to be extracted
+ * @param[in]  index   Index of output to be set (0-indexed)
+ * @param[in]  type    Type of the output
+ * @param[out] buffer  Raw buffer for output
+ * @param[in]  length  Size of the output buffer in bytes
+ *
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_train_set_output(nnfw_session *session, uint32_t index, NNFW_TYPE type,
+                                  void *buffer, size_t length);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif //_ONERT_MICRO_H_
diff --git a/onert-micro/onert-micro/src/CMakeLists.txt b/onert-micro/onert-micro/src/CMakeLists.txt
index 903e93e7e5a..6c2d7a8a733 100644
--- a/onert-micro/onert-micro/src/CMakeLists.txt
+++ b/onert-micro/onert-micro/src/CMakeLists.txt
@@ -13,6 +13,8 @@ set(OM_INCLUDE_OPTIMIZE_DIR "${OM_INCLUDE_DIR}/optimize")
 #define train path
 set(OM_SOURCE_TRAIN_DIR "${OM_SOURCE_DIR}/train")
 set(OM_INCLUDE_TRAIN_DIR "${OM_INCLUDE_DIR}/train")
+#define dev api path
+set(OM_SOURCE_DEV_DIR "${OM_SOURCE_DIR}/api")
 
 #OM_Interpreter lib binary name
 set(OM_INTERPRETER_LIB "onert_micro_interpreter")
@@ -30,6 +32,8 @@ set(OM_OPTIMIZE_LIB "onert_micro_optimize${OM_SUFFIX}")
 set(OM_PAL_LIB "onert_micro_pal${OM_SUFFIX}")
 #Train lib binary name
 set(OM_TRAIN_LIB "onert_micro_train${OM_SUFFIX}")
+#Dev lib binary name
+set(OM_DEV_LIB "onert_micro_dev${OM_SUFFIX}")
 
 message(STATUS "ONERT MICRO BEGIN")
 
@@ -41,6 +45,8 @@ add_subdirectory(${OM_SOURCE_IMPORT_DIR})
 add_subdirectory(${OM_SOURCE_EXECUTE_DIR})
 #build optimize lib
 add_subdirectory(${OM_SOURCE_OPTIMIZE_DIR})
+#build dev lib
+add_subdirectory(${OM_SOURCE_DEV_DIR})
 
 target_link_libraries(${OM_CORE_LIB} PUBLIC ${OM_CIRCLE_SCHEMA})
 target_link_libraries(${OM_CORE_LIB} PUBLIC ${OM_IMPORT_LIB})
diff --git a/onert-micro/onert-micro/src/api/CMakeLists.txt b/onert-micro/onert-micro/src/api/CMakeLists.txt
new file mode 100644
index 00000000000..98baf6e0938
--- /dev/null
+++ b/onert-micro/onert-micro/src/api/CMakeLists.txt
@@ -0,0 +1,14 @@
+message(STATUS "ONERT MICRO DEV BUILD BEGIN")
+
+set(SOURCES
+    onert-micro.cpp)
+
+add_library(${OM_DEV_LIB} STATIC ${SOURCES})
+target_compile_options(${OM_DEV_LIB} PRIVATE "-fexceptions")
+target_link_libraries(${OM_DEV_LIB} PUBLIC ${OM_TRAININFO_SCHEMA})
+target_include_directories(${OM_DEV_LIB} PUBLIC "${OM_INCLUDE_DIR}")
+target_link_libraries(${OM_DEV_LIB} PUBLIC ${OM_INTERPRETER_LIB})
+target_link_libraries(${OM_DEV_LIB} PUBLIC ${OM_TRAINING_INTERPRETER_LIB})
+target_link_libraries(${OM_DEV_LIB} PUBLIC onert_micro_coverage)
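+
+# (illustrative) a consumer target links against the dev lib and includes onert-micro.h, e.g.:
+#   target_link_libraries(my_app PRIVATE ${OM_DEV_LIB})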
+
+message(STATUS "ONERT MICRO DEV BUILD FINISHED")
diff --git a/onert-micro/onert-micro/src/api/onert-micro.cpp b/onert-micro/onert-micro/src/api/onert-micro.cpp
new file mode 100644
index 00000000000..fabaeae7b37
--- /dev/null
+++ b/onert-micro/onert-micro/src/api/onert-micro.cpp
@@ -0,0 +1,436 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cassert>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "OMTrainingInterpreter.h"
+#include "onert-micro.h"
+#include <circle-generated/circle/schema_generated.h>
+#include <circle-generated/circle/traininfo_generated.h>
+
+#define NNFW_RETURN_ERROR_IF_NULL(p)      \
+  do                                      \
+  {                                       \
+    if ((p) == NULL)                      \
+      return NNFW_STATUS_UNEXPECTED_NULL; \
+  } while (0)
+
+// helper for file processing
+using DataBuffer = std::vector<char>;
+
+DataBuffer readFile(const char *path)
+{
+  std::ifstream file(path, std::ios::binary | std::ios::in);
+  if (!file.good())
+  {
+    std::string errmsg = "Failed to open file";
+    throw std::runtime_error(errmsg.c_str());
+  }
+
+  file.seekg(0, std::ios::end);
+  auto fileSize = file.tellg();
+  file.seekg(0, std::ios::beg);
+
+  // allocate a buffer of the file size
+  DataBuffer model_data(fileSize);
+
+  // read the data
+  file.read(model_data.data(), fileSize);
+  if (file.fail())
+  {
+    std::string errmsg = "Failed to read file";
+    throw std::runtime_error(errmsg.c_str());
+  }
+
+  return model_data;
+}
+
+struct nnfw_session
+{
+public:
+  /**
+   * @brief Factory method. It creates and initializes nnfw_session
+   *
+   * @note Use factory instead of constructor to get status
+   */
+  static NNFW_STATUS create(nnfw_session **session);
+
+private:
+  nnfw_session();
+
+public:
+  ~nnfw_session();
+
+  NNFW_STATUS load_model_from_file(const char *package_file_path);
+
+  NNFW_STATUS train_set_traininfo(const nnfw_train_info *info);
+  NNFW_STATUS train_prepare();
+  NNFW_STATUS train_input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
+  NNFW_STATUS train_expected_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
+  NNFW_STATUS train_set_input(uint32_t index, void *input);
+  NNFW_STATUS train_set_expected(uint32_t index, void *expected);
+  NNFW_STATUS train_set_output(uint32_t index, NNFW_TYPE type, void *buffer, size_t length);
+  NNFW_STATUS train_run(bool update_weights);
+  NNFW_STATUS train_get_loss(uint32_t index, float *loss);
+  NNFW_STATUS train_export_circle(const char *path);
+
+  NNFW_STATUS train_export_checkpoint(const char *path);
+  NNFW_STATUS train_import_checkpoint(const char *path);
+
+private:
+  uint32_t getInputSize();
+  uint32_t getOutputSize();
+  NNFW_STATUS loadTrainingInfo(char *buf_ptr);
+  NNFW_STATUS loadOptimizerInfo(const circle::ModelTraining *circle_model);
+  NNFW_STATUS loadLossInfo(const circle::ModelTraining *circle_model);
+  NNFW_STATUS loadTrainableOps(const circle::ModelTraining *circle_model, int num_ops);
+
+private:
+  onert_micro::OMTrainingInterpreter *_train_interpreter;
+  onert_micro::OMConfig _config;
+  DataBuffer _model_buf;
+  std::string _model_path;
+  uint8_t *outputbuf;
+};
+
+nnfw_session::nnfw_session() : _train_interpreter{new onert_micro::OMTrainingInterpreter()}
+{
+  // TODO: Remove after implementing train_set_traininfo
+  // Set user defined training settings
+  const uint32_t training_epochs = 10;
+  const float learning_rate = 0.001f;
+  const uint32_t num_train_layers = 10;
+  const onert_micro::OMLoss loss = onert_micro::CROSS_ENTROPY;
+  const onert_micro::OMTrainOptimizer train_optim = onert_micro::ADAM;
+  const float beta = 0.9f;
+  const float beta_squares = 0.999f;
+  const float epsilon = 1e-07f;
+
+  _config.train_mode = true;
+  {
+    onert_micro::OMTrainingContext train_context;
+    train_context.batch_size = 32;
+    train_context.num_of_train_layers = num_train_layers;
+    train_context.learning_rate = learning_rate;
+    train_context.loss = loss;
+    train_context.optimizer = train_optim;
+    train_context.beta = beta;
+    train_context.beta_squares = beta_squares;
+    train_context.epsilon = epsilon;
+    train_context.num_step = 0;
+
+    _config.training_context = train_context;
+  }
+
+  outputbuf = nullptr;
+}
+
+NNFW_STATUS nnfw_session::create(nnfw_session **session)
+{
+  if (session == nullptr)
+    return NNFW_STATUS_UNEXPECTED_NULL;
+
+  auto new_session = std::unique_ptr<nnfw_session>(new nnfw_session());
+  *session = new_session.release();
+
+  if (*session == nullptr)
+  {
+    return NNFW_STATUS_ERROR;
+  }
+
+  return NNFW_STATUS_NO_ERROR;
+}
+
+nnfw_session::~nnfw_session() { delete _train_interpreter; }
+
+NNFW_STATUS nnfw_session::loadOptimizerInfo(const circle::ModelTraining *circle_model)
+{
+  assert(circle_model != nullptr);
+
+  const circle::Optimizer circle_opt = circle_model->optimizer();
+
+  switch (circle_opt)
+  {
+    case circle::Optimizer_SGD:
+      _config.training_context.optimizer = onert_micro::SGD;
+      _config.training_context.learning_rate =
+        circle_model->optimizer_opt_as_SGDOptions()->learning_rate();
+      break;
+    case circle::Optimizer_ADAM:
+      _config.training_context.optimizer = onert_micro::ADAM;
+      _config.training_context.learning_rate =
+        circle_model->optimizer_opt_as_AdamOptions()->learning_rate();
+      _config.training_context.beta = circle_model->optimizer_opt_as_AdamOptions()->beta_1();
+      _config.training_context.beta_squares =
+        circle_model->optimizer_opt_as_AdamOptions()->beta_2();
+      _config.training_context.epsilon = circle_model->optimizer_opt_as_AdamOptions()->epsilon();
+      break;
+    default:
+      std::cerr << "unknown optimizer" << std::endl;
+      return NNFW_STATUS_ERROR;
+  }
+  return NNFW_STATUS_NO_ERROR;
+}
+
+NNFW_STATUS nnfw_session::loadLossInfo(const circle::ModelTraining *circle_model)
+{
+  assert(circle_model != nullptr);
+
+  const circle::LossFn circle_loss = circle_model->lossfn();
+
+  switch (circle_loss)
+  {
+    case circle::LossFn::LossFn_CATEGORICAL_CROSSENTROPY:
+      _config.training_context.loss = onert_micro::CROSS_ENTROPY;
+      break;
+    case circle::LossFn::LossFn_MEAN_SQUARED_ERROR:
+      _config.training_context.loss = onert_micro::MSE;
+      break;
+    case circle::LossFn::LossFn_SPARSE_CATEGORICAL_CROSSENTROPY:
+      // TODO enable this conversion after the core supports sparse_categorical_crossentropy
+      std::cerr << "'sparse_categorical_crossentropy' is not supported yet" << std::endl;
+      return NNFW_STATUS_ERROR;
+    default:
+      std::cerr << "unknown loss function" << std::endl;
+      return NNFW_STATUS_ERROR;
+  }
+  return NNFW_STATUS_NO_ERROR;
+}
+
+NNFW_STATUS nnfw_session::loadTrainableOps(const circle::ModelTraining *circle_model, int num_ops)
+{
+  assert(circle_model != nullptr);
+
+  auto ops_list = circle_model->trainable_ops();
+  if (ops_list != nullptr)
+    _config.training_context.num_of_train_layers =
+      num_ops - ops_list->data()[0]; // simply assume ops[0] is the lowest node number
+  else
+    _config.training_context.num_of_train_layers = num_ops;
+  return NNFW_STATUS_NO_ERROR;
+}
+
+NNFW_STATUS nnfw_session::loadTrainingInfo(char *buf)
+{
+  auto model = circle::GetModel(buf);
+  auto num_ops = model->subgraphs()->Get(0)->operators()->size();
+  // Load Metadata
+  auto const metadata_list = model->metadata();
+  const uint8_t *data = nullptr;
+  if (metadata_list != nullptr)
+  {
+    for (uint32_t i = 0; i < metadata_list->size(); ++i)
+    {
+      const auto metadata = metadata_list->Get(i);
+      if (strcmp(metadata->name()->c_str(), "CIRCLE_TRAINING") != 0)
+        continue;
+      data = (model->buffers()->Get(metadata->buffer()))->data()->data();
+    }
+    // proceed only if CIRCLE_TRAINING metadata was actually found
+    if (data == nullptr)
+      return NNFW_STATUS_NO_ERROR;
+    const circle::ModelTraining *traininfo_model =
+      circle::GetModelTraining(static_cast<const void *>(data));
+    _config.training_context.batch_size = traininfo_model->batch_size();
+    loadOptimizerInfo(traininfo_model);
+    loadLossInfo(traininfo_model);
+    loadTrainableOps(traininfo_model, num_ops);
+  }
+  return NNFW_STATUS_NO_ERROR;
+}
+
+NNFW_STATUS nnfw_session::load_model_from_file(const char *file_path)
+{
+  _model_buf = readFile(file_path);
+  _config.model_ptr = _model_buf.data();
+  _config.model_size = _model_buf.size();
+  // load training info
+  loadTrainingInfo(_config.model_ptr);
+  // TODO: this import should start on nnfw_prepare if an inference interpreter is introduced
+  _train_interpreter->importTrainModel(_config.model_ptr, _config);
+  return NNFW_STATUS_NO_ERROR;
+}
+
+NNFW_STATUS nnfw_session::train_prepare()
+{
+  // TODO: Implement remaining jobs if an inference interpreter is introduced
+  // (maybe interpreter initialization?)
+  return NNFW_STATUS_NO_ERROR;
+}
+
+NNFW_STATUS nnfw_session::train_run(bool update_weights)
+{
+  if (update_weights)
+  {
+    // TODO: check whether micro supports not updating weights;
+    // here the flag only distinguishes training from inference in the training interpreter
+    _train_interpreter->trainSingleStep(_config);
+    _config.training_context.num_epoch =
+      _config.training_context.num_step / _config.training_context.batch_size + 1;
+  }
+  else
+  {
+    // TODO: support multiple inputs/outputs
+    assert(outputbuf != nullptr);
+    _train_interpreter->allocateInputs();
+    float *allocated_input_data = (float *)_train_interpreter->getInputDataAt(0);
+    float *user_input_data = (float *)_train_interpreter->getInputData(0);
+    memcpy(allocated_input_data, user_input_data,
+           sizeof(float) * _train_interpreter->getInputSizeAt(0));
+    _train_interpreter->run();
+    float *calculated_ptr = (float *)_train_interpreter->getOutputDataAt(0);
+    memcpy(outputbuf, calculated_ptr, sizeof(float) * _train_interpreter->getOutputSizeAt(0));
+    _train_interpreter->reset();
+  }
+  return NNFW_STATUS_NO_ERROR;
+}
+
+NNFW_STATUS nnfw_session::train_export_circle(const char *path)
+{
+  _train_interpreter->saveModel(_config, path);
+  return NNFW_STATUS_NO_ERROR;
+}
+
+NNFW_STATUS nnfw_session::train_export_checkpoint(const char *path)
+{
+  _train_interpreter->saveCheckpoint(_config, path);
+  return NNFW_STATUS_NO_ERROR;
+}
+
+NNFW_STATUS nnfw_session::train_import_checkpoint(const char *path)
+{
+  _train_interpreter->loadCheckpoint(_config, path);
+  return NNFW_STATUS_NO_ERROR;
+}
+
+// TODO: onert's version of this function takes a const input buffer
+NNFW_STATUS nnfw_session::train_set_input(uint32_t index, void *input)
+{
+  _train_interpreter->setInput((uint8_t *)input, index);
+  return NNFW_STATUS_NO_ERROR;
+}
+
+// TODO: onert's version of this function takes a const expected buffer
+NNFW_STATUS nnfw_session::train_set_expected(uint32_t index, void *expected)
+{
+  _train_interpreter->setTarget((uint8_t *)expected, index);
+  return NNFW_STATUS_NO_ERROR;
+}
+
+NNFW_STATUS nnfw_session::train_set_output(uint32_t index, NNFW_TYPE type, void *buffer,
+                                           size_t length)
+{
+  outputbuf = (uint8_t *)buffer;
+  return NNFW_STATUS_NO_ERROR;
+}
+
+NNFW_STATUS nnfw_session::train_set_traininfo(const nnfw_train_info *info)
+{
+  _config.training_context.learning_rate = info->learning_rate;
+  _config.training_context.batch_size = info->batch_size;
+  _config.training_context.optimizer =
+    (info->opt == NNFW_TRAIN_OPTIMIZER_ADAM) ? onert_micro::ADAM : onert_micro::SGD;
+  _config.training_context.beta = info->adam_opt.beta;
+  _config.training_context.beta_squares = info->adam_opt.beta2;
+  _config.training_context.epsilon = info->adam_opt.epsilon;
+  _config.training_context.num_of_train_layers = info->num_trainable_ops;
+  return NNFW_STATUS_NO_ERROR;
+}
+
+NNFW_STATUS nnfw_session::train_get_loss(uint32_t index, float *loss)
+{
+  onert_micro::OMMetrics m;
+  switch (_config.training_context.loss)
+  {
+    case onert_micro::CROSS_ENTROPY:
+      m = onert_micro::CROSS_ENTROPY_METRICS;
+      break;
+    default:
+      // TODO: support metrics for the other loss types
+      m = onert_micro::CROSS_ENTROPY_METRICS;
+      break;
+  }
+
+  _train_interpreter->evaluateMetric(m, reinterpret_cast<void *>(loss),
+                                     _config.training_context.batch_size);
+  return NNFW_STATUS_NO_ERROR;
+}
+
+// onert-micro.h implementation
+
+NNFW_STATUS nnfw_create_session(nnfw_session **session) { return nnfw_session::create(session); }
+
+NNFW_STATUS nnfw_load_model_from_file(nnfw_session *session, const char *package_file_path)
+{
+  NNFW_RETURN_ERROR_IF_NULL(session);
+  return session->load_model_from_file(package_file_path);
+}
+
+NNFW_STATUS nnfw_train_prepare(nnfw_session *session)
+{
+  NNFW_RETURN_ERROR_IF_NULL(session);
+  return session->train_prepare();
+}
+
+NNFW_STATUS nnfw_train(nnfw_session *session, bool update_weights)
+{
+  NNFW_RETURN_ERROR_IF_NULL(session);
+  return session->train_run(update_weights);
+}
+
+NNFW_STATUS nnfw_train_export_circle(nnfw_session *session, const char *path)
+{
+  NNFW_RETURN_ERROR_IF_NULL(session);
+  return session->train_export_circle(path);
+}
+
+NNFW_STATUS nnfw_train_export_checkpoint(nnfw_session *session, const char *path)
+{
+  NNFW_RETURN_ERROR_IF_NULL(session);
+  return session->train_export_checkpoint(path);
+}
+
+NNFW_STATUS nnfw_train_import_checkpoint(nnfw_session *session, const char *path)
+{
+  NNFW_RETURN_ERROR_IF_NULL(session);
+  return session->train_import_checkpoint(path);
+}
+
+NNFW_STATUS nnfw_train_set_input(nnfw_session *session, uint32_t index, void *input,
+                                 const nnfw_tensorinfo *input_info)
+{
+  NNFW_RETURN_ERROR_IF_NULL(session);
+  return session->train_set_input(index, input);
+}
+
+NNFW_STATUS nnfw_train_set_expected(nnfw_session *session, uint32_t index, void *expected,
+                                    const nnfw_tensorinfo *expected_info)
+{
+  NNFW_RETURN_ERROR_IF_NULL(session);
+  return session->train_set_expected(index, expected);
+}
+
+NNFW_STATUS nnfw_train_get_loss(nnfw_session *session, uint32_t index, float *loss)
+{
+  NNFW_RETURN_ERROR_IF_NULL(session);
+  return session->train_get_loss(index, loss);
+}
+
+NNFW_STATUS nnfw_train_set_traininfo(nnfw_session *session, const nnfw_train_info *info)
+{
+  NNFW_RETURN_ERROR_IF_NULL(session);
+  return session->train_set_traininfo(info);
+}
+
+NNFW_STATUS nnfw_train_set_output(nnfw_session *session, uint32_t index, NNFW_TYPE type,
+                                  void *buffer, size_t length)
+{
+  NNFW_RETURN_ERROR_IF_NULL(session);
+  return session->train_set_output(index, type, buffer, length);
+}