Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[onert-micro] Enable Training #13256

Merged
merged 1 commit into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions onert-micro/onert-micro/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ else()
set(KERNEL_REGISTER_FILE ${OM_KERNELS_BUILD_LIST})
endif()

# if no path is specified, use classic to KernelsToBuild list, else use generated one
if (NOT OM_KERNELS_TRAIN_LIST)
set(KERNEL_TRAIN_REGISTER_FILE "${OM_PAL_DIR}/KernelsToTrain.lst")
else()
set(KERNEL_TRAIN_REGISTER_FILE ${OM_KERNELS_TRAIN_LIST})
endif()

set(PASS_REGISTER_FILE "${OM_INCLUDE_DIR}/optimize/BuildPass.lst")
set(CUSTOM_KERNEL_REGISTER_FILE "${OM_PAL_DIR}/CustomKernelsToBuild.lst")

Expand Down
1 change: 1 addition & 0 deletions onert-micro/onert-micro/include/OMConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ struct OMTrainingContext
float epsilon = 10e-8;
uint32_t num_step = 0;
uint32_t num_epoch = 0;
uint32_t epochs = 0;
};

/*
Expand Down
5 changes: 5 additions & 0 deletions onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
REGISTER_TRAIN_KERNEL(FULLY_CONNECTED, FullyConnected)
REGISTER_TRAIN_KERNEL(SOFTMAX, Softmax)
REGISTER_TRAIN_KERNEL(RESHAPE, Reshape)
REGISTER_TRAIN_KERNEL(CONV_2D, Conv2D)
REGISTER_TRAIN_KERNEL(MAX_POOL_2D, MaxPool2D)
23 changes: 23 additions & 0 deletions onert-micro/onert-micro/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@ set(OM_INCLUDE_IMPORT_DIR "${OM_INCLUDE_DIR}/import")
#define optimize path
set(OM_SOURCE_OPTIMIZE_DIR "${OM_SOURCE_DIR}/optimize")
set(OM_INCLUDE_OPTIMIZE_DIR "${OM_INCLUDE_DIR}/optimize")
#define train path
set(OM_SOURCE_TRAIN_DIR "${OM_SOURCE_DIR}/train")
set(OM_INCLUDE_TRAIN_DIR "${OM_INCLUDE_DIR}/train")

#OM_Interpreter lib binary name
set(OM_INTERPRETER_LIB "onert_micro_interpreter")
#OM_Training_Interpreter lib binary name
set(OM_TRAINING_INTERPRETER_LIB "onert_micro_training_interpreter")
#Core lib binary name
set(OM_CORE_LIB "onert_micro_core${OM_SUFFIX}")
#Execute lib binary name
Expand All @@ -23,6 +28,8 @@ set(OM_IMPORT_LIB "onert_micro_import${OM_SUFFIX}")
set(OM_OPTIMIZE_LIB "onert_micro_optimize${OM_SUFFIX}")
#PAL lib binary name
set(OM_PAL_LIB "onert_micro_pal${OM_SUFFIX}")
#Train lib binary name
set(OM_TRAIN_LIB "onert_micro_train${OM_SUFFIX}")

message(STATUS "ONERT MICRO BEGIN")

Expand All @@ -45,4 +52,20 @@ add_library(${OM_INTERPRETER_LIB} STATIC OMInterpreter.cpp)
target_include_directories(${OM_INTERPRETER_LIB} PUBLIC "${OM_INCLUDE_DIR}")
target_link_libraries(${OM_INTERPRETER_LIB} PUBLIC ${OM_CORE_LIB})

# Training part
message (STATUS "ONERT MICRO TRAINING BEGIN")

#build train lib
add_subdirectory(${OM_SOURCE_TRAIN_DIR})

target_link_libraries(${OM_CORE_LIB} PUBLIC ${OM_TRAIN_LIB})

add_library(${OM_TRAINING_INTERPRETER_LIB} STATIC OMTrainingInterpreter.cpp)

target_include_directories(${OM_TRAINING_INTERPRETER_LIB} PUBLIC "${OM_INCLUDE_DIR}")
target_link_libraries(${OM_TRAINING_INTERPRETER_LIB} PUBLIC ${OM_CORE_LIB})
target_link_libraries(${OM_TRAINING_INTERPRETER_LIB} PUBLIC ${OM_TRAIN_LIB})

message (STATUS "ONERT MICRO TRAINING END")

message(STATUS "ONERT MICRO FINISHED")
5 changes: 5 additions & 0 deletions onert-micro/onert-micro/src/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@ set(SOURCES
OMKernelType.cpp
OMRuntimeContext.cpp
OMRuntimeStorage.cpp
OMTrainingRuntimeModule.cpp
OMRuntimeGraph.cpp
OMRuntimeModule.cpp
OMUtils.cpp
OMDataType.cpp
train/OMTrainingHandler.cpp
train/OMTrainingStorage.cpp
train/OMCheckpointSaver.cpp
train/OMCheckpointLoader.cpp
memory/OMMemoryManager.cpp
memory/OMRuntimeAllocator.cpp
reader/OMCircleReader.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ OMStatus OMRuntimeAllocator::deallocate(size_t kernel_index, OMRuntimeStorage *s
{
uint8_t *allocated_data = nullptr;
OMStatus status = storage->getDataByTensorIndex(&allocated_data, tensor_index);
// To continue deallocate due to current tensor is not saved in storage
if (allocated_data == nullptr)
continue;
if (status != Ok)
return status;

Expand Down
Loading