Skip to content

Commit

Permalink
aux funcs in OMTrainingInterpreter (#13263)
Browse files Browse the repository at this point in the history
* aux funcs in OMTrainingInterpreter

- run(), allocateInputs(), getInputData, getInputDataAt, getOutputDataAt

ONE-DCO-1.0-Signed-off-by: Chunseok Lee <[email protected]>

* fix format

---------

Co-authored-by: chunseoklee <[email protected]>
  • Loading branch information
chunseoklee and chunseoklee authored Jun 21, 2024
1 parent 506f7b4 commit 75e6d15
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 0 deletions.
7 changes: 7 additions & 0 deletions onert-micro/onert-micro/include/OMTrainingInterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ class OMTrainingInterpreter

// Load current status from checkpoint and save it in current model and in current config
OMStatus loadCheckpoint(OMConfig &config, const char *load_path);

OMStatus run() { return _training_runtime_module.run(); }
OMStatus allocateInputs() { return _training_runtime_module.allocateInputs(); }

void *getInputData(uint32_t position);
void *getInputDataAt(uint32_t position);
void *getOutputDataAt(uint32_t position);
};

} // namespace onert_micro
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class OMTrainingRuntimeModule : public OMRuntimeModule
// Load checkpoints data and save it in model data and in config
// To check checkpoint file format please see https://github.com/Samsung/ONE/discussions/13037
OMStatus loadCheckpointData(OMConfig &config, const char *data);

void *getInputData(int32_t index);
};

} // namespace core
Expand Down
15 changes: 15 additions & 0 deletions onert-micro/onert-micro/src/OMTrainingInterpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,18 @@ OMStatus OMTrainingInterpreter::saveCheckpoint(const OMConfig &config, const cha

return Ok;
}

void *OMTrainingInterpreter::getInputDataAt(uint32_t position)
{
return _training_runtime_module.getInputDataAt(position);
}

void *OMTrainingInterpreter::getOutputDataAt(uint32_t position)
{
return _training_runtime_module.getOutputDataAt(position);
}

void *OMTrainingInterpreter::getInputData(uint32_t position)
{
return _training_runtime_module.getInputData(position);
}
5 changes: 5 additions & 0 deletions onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,8 @@ OMStatus OMTrainingRuntimeModule::loadCheckpointData(OMConfig &config, const cha

return status;
}

void *OMTrainingRuntimeModule::getInputData(int32_t index)
{
return _training_handler.getInputData(index);
}

0 comments on commit 75e6d15

Please sign in to comment.