Skip to content

Commit

Permalink
[onert-micro] Enable Training (#13256)
Browse files Browse the repository at this point in the history
This pr enables training via CMakeLists.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
BalyshevArtem authored Jun 20, 2024
1 parent 8190370 commit 30b87c7
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 0 deletions.
7 changes: 7 additions & 0 deletions onert-micro/onert-micro/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ else()
set(KERNEL_REGISTER_FILE ${OM_KERNELS_BUILD_LIST})
endif()

# if no path is specified, use classic to KernelsToBuild list, else use generated one
if (NOT OM_KERNELS_TRAIN_LIST)
set(KERNEL_TRAIN_REGISTER_FILE "${OM_PAL_DIR}/KernelsToTrain.lst")
else()
set(KERNEL_TRAIN_REGISTER_FILE ${OM_KERNELS_TRAIN_LIST})
endif()

set(PASS_REGISTER_FILE "${OM_INCLUDE_DIR}/optimize/BuildPass.lst")
set(CUSTOM_KERNEL_REGISTER_FILE "${OM_PAL_DIR}/CustomKernelsToBuild.lst")

Expand Down
1 change: 1 addition & 0 deletions onert-micro/onert-micro/include/OMConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ struct OMTrainingContext
float epsilon = 10e-8;
uint32_t num_step = 0;
uint32_t num_epoch = 0;
uint32_t epochs = 0;
};

/*
Expand Down
5 changes: 5 additions & 0 deletions onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
REGISTER_TRAIN_KERNEL(FULLY_CONNECTED, FullyConnected)
REGISTER_TRAIN_KERNEL(SOFTMAX, Softmax)
REGISTER_TRAIN_KERNEL(RESHAPE, Reshape)
REGISTER_TRAIN_KERNEL(CONV_2D, Conv2D)
REGISTER_TRAIN_KERNEL(MAX_POOL_2D, MaxPool2D)
23 changes: 23 additions & 0 deletions onert-micro/onert-micro/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@ set(OM_INCLUDE_IMPORT_DIR "${OM_INCLUDE_DIR}/import")
#define optimize path
set(OM_SOURCE_OPTIMIZE_DIR "${OM_SOURCE_DIR}/optimize")
set(OM_INCLUDE_OPTIMIZE_DIR "${OM_INCLUDE_DIR}/optimize")
#define train path
set(OM_SOURCE_TRAIN_DIR "${OM_SOURCE_DIR}/train")
set(OM_INCLUDE_TRAIN_DIR "${OM_INCLUDE_DIR}/train")

#OM_Interpreter lib binary name
set(OM_INTERPRETER_LIB "onert_micro_interpreter")
#OM_Training_Interpreter lib binary name
set(OM_TRAINING_INTERPRETER_LIB "onert_micro_training_interpreter")
#Core lib binary name
set(OM_CORE_LIB "onert_micro_core${OM_SUFFIX}")
#Execute lib binary name
Expand All @@ -23,6 +28,8 @@ set(OM_IMPORT_LIB "onert_micro_import${OM_SUFFIX}")
set(OM_OPTIMIZE_LIB "onert_micro_optimize${OM_SUFFIX}")
#PAL lib binary name
set(OM_PAL_LIB "onert_micro_pal${OM_SUFFIX}")
#Train lib binary name
set(OM_TRAIN_LIB "onert_micro_train${OM_SUFFIX}")

message(STATUS "ONERT MICRO BEGIN")

Expand All @@ -45,4 +52,20 @@ add_library(${OM_INTERPRETER_LIB} STATIC OMInterpreter.cpp)
target_include_directories(${OM_INTERPRETER_LIB} PUBLIC "${OM_INCLUDE_DIR}")
target_link_libraries(${OM_INTERPRETER_LIB} PUBLIC ${OM_CORE_LIB})

# Training part
message (STATUS "ONERT MICRO TRAINING BEGIN")

#build train lib
add_subdirectory(${OM_SOURCE_TRAIN_DIR})

target_link_libraries(${OM_CORE_LIB} PUBLIC ${OM_TRAIN_LIB})

add_library(${OM_TRAINING_INTERPRETER_LIB} STATIC OMTrainingInterpreter.cpp)

target_include_directories(${OM_TRAINING_INTERPRETER_LIB} PUBLIC "${OM_INCLUDE_DIR}")
target_link_libraries(${OM_TRAINING_INTERPRETER_LIB} PUBLIC ${OM_CORE_LIB})
target_link_libraries(${OM_TRAINING_INTERPRETER_LIB} PUBLIC ${OM_TRAIN_LIB})

message (STATUS "ONERT MICRO TRAINING END")

message(STATUS "ONERT MICRO FINISHED")
5 changes: 5 additions & 0 deletions onert-micro/onert-micro/src/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@ set(SOURCES
OMKernelType.cpp
OMRuntimeContext.cpp
OMRuntimeStorage.cpp
OMTrainingRuntimeModule.cpp
OMRuntimeGraph.cpp
OMRuntimeModule.cpp
OMUtils.cpp
OMDataType.cpp
train/OMTrainingHandler.cpp
train/OMTrainingStorage.cpp
train/OMCheckpointSaver.cpp
train/OMCheckpointLoader.cpp
memory/OMMemoryManager.cpp
memory/OMRuntimeAllocator.cpp
reader/OMCircleReader.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ OMStatus OMRuntimeAllocator::deallocate(size_t kernel_index, OMRuntimeStorage *s
{
uint8_t *allocated_data = nullptr;
OMStatus status = storage->getDataByTensorIndex(&allocated_data, tensor_index);
// To continue deallocate due to current tensor is not saved in storage
if (allocated_data == nullptr)
continue;
if (status != Ok)
return status;

Expand Down

0 comments on commit 30b87c7

Please sign in to comment.