diff --git a/onert-micro/onert-micro/include/OMTrainingInterpreter.h b/onert-micro/onert-micro/include/OMTrainingInterpreter.h index c2c849f4c7f..0c1a35defc8 100644 --- a/onert-micro/onert-micro/include/OMTrainingInterpreter.h +++ b/onert-micro/onert-micro/include/OMTrainingInterpreter.h @@ -85,6 +85,13 @@ class OMTrainingInterpreter // Load current status from checkpoint and save it in current model and in current config OMStatus loadCheckpoint(OMConfig &config, const char *load_path); + + OMStatus run() { return _training_runtime_module.run(); } + OMStatus allocateInputs() { return _training_runtime_module.allocateInputs(); } + + void *getInputData(uint32_t position); + void *getInputDataAt(uint32_t position); + void *getOutputDataAt(uint32_t position); }; } // namespace onert_micro diff --git a/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h b/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h index 081ff8c6d9a..d9201374b5e 100644 --- a/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h +++ b/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h @@ -90,6 +90,8 @@ class OMTrainingRuntimeModule : public OMRuntimeModule // Load checkpoints data and save it in model data and in config // To check checkpoint file format please see https://github.com/Samsung/ONE/discussions/13037 OMStatus loadCheckpointData(OMConfig &config, const char *data); + + void *getInputData(int32_t index); }; } // namespace core diff --git a/onert-micro/onert-micro/src/OMTrainingInterpreter.cpp b/onert-micro/onert-micro/src/OMTrainingInterpreter.cpp index c3f7bbc6b05..e5420909a34 100644 --- a/onert-micro/onert-micro/src/OMTrainingInterpreter.cpp +++ b/onert-micro/onert-micro/src/OMTrainingInterpreter.cpp @@ -152,3 +152,18 @@ OMStatus OMTrainingInterpreter::saveCheckpoint(const OMConfig &config, const cha return Ok; } + +void *OMTrainingInterpreter::getInputDataAt(uint32_t position) +{ + return _training_runtime_module.getInputDataAt(position); +} + +void *OMTrainingInterpreter::getOutputDataAt(uint32_t position) +{ + return _training_runtime_module.getOutputDataAt(position); +} + +void *OMTrainingInterpreter::getInputData(uint32_t position) +{ + return _training_runtime_module.getInputData(position); +} diff --git a/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp b/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp index 336870b85cc..35ee91a47ed 100644 --- a/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp +++ b/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp @@ -302,3 +302,8 @@ OMStatus OMTrainingRuntimeModule::loadCheckpointData(OMConfig &config, const cha return status; } + +void *OMTrainingRuntimeModule::getInputData(int32_t index) +{ + return _training_handler.getInputData(index); +}