From 30b87c7a01b4a820ed59623c8551d0e752ade4ba Mon Sep 17 00:00:00 2001 From: Balyshev Artem <43214667+BalyshevArtem@users.noreply.github.com> Date: Thu, 20 Jun 2024 12:38:58 +0300 Subject: [PATCH] [onert-micro] Enable Training (#13256) This pr enables training via CMakeLists. ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- onert-micro/onert-micro/CMakeLists.txt | 7 ++++++ onert-micro/onert-micro/include/OMConfig.h | 1 + .../include/pal/mcu/KernelsToTrain.lst | 5 ++++ onert-micro/onert-micro/src/CMakeLists.txt | 23 +++++++++++++++++++ .../onert-micro/src/core/CMakeLists.txt | 5 ++++ .../src/core/memory/OMRuntimeAllocator.cpp | 3 +++ 6 files changed, 44 insertions(+) diff --git a/onert-micro/onert-micro/CMakeLists.txt b/onert-micro/onert-micro/CMakeLists.txt index 8cfe856962e..a48ad1628fe 100644 --- a/onert-micro/onert-micro/CMakeLists.txt +++ b/onert-micro/onert-micro/CMakeLists.txt @@ -13,6 +13,13 @@ else() set(KERNEL_REGISTER_FILE ${OM_KERNELS_BUILD_LIST}) endif() +# if no path is specified, use classic to KernelsToBuild list, else use generated one +if (NOT OM_KERNELS_TRAIN_LIST) + set(KERNEL_TRAIN_REGISTER_FILE "${OM_PAL_DIR}/KernelsToTrain.lst") +else() + set(KERNEL_TRAIN_REGISTER_FILE ${OM_KERNELS_TRAIN_LIST}) +endif() + set(PASS_REGISTER_FILE "${OM_INCLUDE_DIR}/optimize/BuildPass.lst") set(CUSTOM_KERNEL_REGISTER_FILE "${OM_PAL_DIR}/CustomKernelsToBuild.lst") diff --git a/onert-micro/onert-micro/include/OMConfig.h b/onert-micro/onert-micro/include/OMConfig.h index f7fa9704bfe..0ffbf024b1f 100644 --- a/onert-micro/onert-micro/include/OMConfig.h +++ b/onert-micro/onert-micro/include/OMConfig.h @@ -78,6 +78,7 @@ struct OMTrainingContext float epsilon = 10e-8; uint32_t num_step = 0; uint32_t num_epoch = 0; + uint32_t epochs = 0; }; /* diff --git a/onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst b/onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst index e69de29bb2d..26476df574f 100644 --- a/onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst +++ b/onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst @@ -0,0 +1,5 @@ +REGISTER_TRAIN_KERNEL(FULLY_CONNECTED, FullyConnected) +REGISTER_TRAIN_KERNEL(SOFTMAX, Softmax) +REGISTER_TRAIN_KERNEL(RESHAPE, Reshape) +REGISTER_TRAIN_KERNEL(CONV_2D, Conv2D) +REGISTER_TRAIN_KERNEL(MAX_POOL_2D, MaxPool2D) diff --git a/onert-micro/onert-micro/src/CMakeLists.txt b/onert-micro/onert-micro/src/CMakeLists.txt index cf6b52d8bd3..903e93e7e5a 100644 --- a/onert-micro/onert-micro/src/CMakeLists.txt +++ b/onert-micro/onert-micro/src/CMakeLists.txt @@ -10,9 +10,14 @@ set(OM_INCLUDE_IMPORT_DIR "${OM_INCLUDE_DIR}/import") #define optimize path set(OM_SOURCE_OPTIMIZE_DIR "${OM_SOURCE_DIR}/optimize") set(OM_INCLUDE_OPTIMIZE_DIR "${OM_INCLUDE_DIR}/optimize") +#define train path +set(OM_SOURCE_TRAIN_DIR "${OM_SOURCE_DIR}/train") +set(OM_INCLUDE_TRAIN_DIR "${OM_INCLUDE_DIR}/train") #OM_Interpreter lib binary name set(OM_INTERPRETER_LIB "onert_micro_interpreter") +#OM_Training_Interpreter lib binary name +set(OM_TRAINING_INTERPRETER_LIB "onert_micro_training_interpreter") #Core lib binary name set(OM_CORE_LIB "onert_micro_core${OM_SUFFIX}") #Execute lib binary name @@ -23,6 +28,8 @@ set(OM_IMPORT_LIB "onert_micro_import${OM_SUFFIX}") set(OM_OPTIMIZE_LIB "onert_micro_optimize${OM_SUFFIX}") #PAL lib binary name set(OM_PAL_LIB "onert_micro_pal${OM_SUFFIX}") +#Train lib binary name +set(OM_TRAIN_LIB "onert_micro_train${OM_SUFFIX}") message(STATUS "ONERT MICRO BEGIN") @@ -45,4 +52,20 @@ add_library(${OM_INTERPRETER_LIB} STATIC OMInterpreter.cpp) target_include_directories(${OM_INTERPRETER_LIB} PUBLIC "${OM_INCLUDE_DIR}") target_link_libraries(${OM_INTERPRETER_LIB} PUBLIC ${OM_CORE_LIB}) +# Training part +message (STATUS "ONERT MICRO TRAINING BEGIN") + +#build train lib +add_subdirectory(${OM_SOURCE_TRAIN_DIR}) + +target_link_libraries(${OM_CORE_LIB} PUBLIC ${OM_TRAIN_LIB}) + +add_library(${OM_TRAINING_INTERPRETER_LIB} STATIC OMTrainingInterpreter.cpp) + +target_include_directories(${OM_TRAINING_INTERPRETER_LIB} PUBLIC "${OM_INCLUDE_DIR}") +target_link_libraries(${OM_TRAINING_INTERPRETER_LIB} PUBLIC ${OM_CORE_LIB}) +target_link_libraries(${OM_TRAINING_INTERPRETER_LIB} PUBLIC ${OM_TRAIN_LIB}) + +message (STATUS "ONERT MICRO TRAINING END") + message(STATUS "ONERT MICRO FINISHED") diff --git a/onert-micro/onert-micro/src/core/CMakeLists.txt b/onert-micro/onert-micro/src/core/CMakeLists.txt index 86935c3913d..4ed5c11df70 100644 --- a/onert-micro/onert-micro/src/core/CMakeLists.txt +++ b/onert-micro/onert-micro/src/core/CMakeLists.txt @@ -4,10 +4,15 @@ set(SOURCES OMKernelType.cpp OMRuntimeContext.cpp OMRuntimeStorage.cpp + OMTrainingRuntimeModule.cpp OMRuntimeGraph.cpp OMRuntimeModule.cpp OMUtils.cpp OMDataType.cpp + train/OMTrainingHandler.cpp + train/OMTrainingStorage.cpp + train/OMCheckpointSaver.cpp + train/OMCheckpointLoader.cpp memory/OMMemoryManager.cpp memory/OMRuntimeAllocator.cpp reader/OMCircleReader.cpp diff --git a/onert-micro/onert-micro/src/core/memory/OMRuntimeAllocator.cpp b/onert-micro/onert-micro/src/core/memory/OMRuntimeAllocator.cpp index f867c6823d8..acbbc11b63a 100644 --- a/onert-micro/onert-micro/src/core/memory/OMRuntimeAllocator.cpp +++ b/onert-micro/onert-micro/src/core/memory/OMRuntimeAllocator.cpp @@ -94,6 +94,9 @@ OMStatus OMRuntimeAllocator::deallocate(size_t kernel_index, OMRuntimeStorage *s { uint8_t *allocated_data = nullptr; OMStatus status = storage->getDataByTensorIndex(&allocated_data, tensor_index); + // To continue deallocate due to current tensor is not saved in storage + if (allocated_data == nullptr) + continue; if (status != Ok) return status;