diff --git a/runtime/onert/backend/train/BackendContext.cc b/runtime/onert/backend/train/BackendContext.cc index f5bf9999671..06a224e1ddc 100644 --- a/runtime/onert/backend/train/BackendContext.cc +++ b/runtime/onert/backend/train/BackendContext.cc @@ -139,16 +139,9 @@ getDisposableBackPropTensorList(const ir::train::TrainableGraph &tgraph, backend::ITensorRegistry *BackendContext::genTensors() { planForwardTensors(); - - _tensor_builder->allocate(); - - return _tensor_registry.get(); -} - -backend::train::ITensorRegistry *BackendContext::genTrainingTensors() -{ planBackwardTensors(); + _tensor_builder->allocate(); _tensor_builder->allocateBackward(); return _tensor_registry.get(); diff --git a/runtime/onert/backend/train/BackendContext.h b/runtime/onert/backend/train/BackendContext.h index 69d17d352c4..6ab458c437a 100644 --- a/runtime/onert/backend/train/BackendContext.h +++ b/runtime/onert/backend/train/BackendContext.h @@ -69,7 +69,6 @@ class BackendContext : public onert::backend::train::TrainableBackendContext public: backend::ITensorRegistry *genTensors() override; - backend::train::ITensorRegistry *genTrainingTensors() override; private: void planForwardTensors(); diff --git a/runtime/onert/core/include/backend/train/TrainableBackendContext.h b/runtime/onert/core/include/backend/train/TrainableBackendContext.h index b3a9cdd7d52..36492786b95 100644 --- a/runtime/onert/core/include/backend/train/TrainableBackendContext.h +++ b/runtime/onert/core/include/backend/train/TrainableBackendContext.h @@ -76,7 +76,6 @@ class TrainableBackendContext std::shared_ptr tensor_registry() { return _tensor_registry; } - virtual ITensorRegistry *genTrainingTensors() = 0; virtual backend::ITensorRegistry *genTensors() = 0; virtual FunctionMap genKernels() = 0; diff --git a/runtime/onert/core/src/backend/builtin/train/BackendContext.cc b/runtime/onert/core/src/backend/builtin/train/BackendContext.cc index 69483eade12..c415a5557f5 100644 --- a/runtime/onert/core/src/backend/builtin/train/BackendContext.cc +++ b/runtime/onert/core/src/backend/builtin/train/BackendContext.cc @@ -30,19 +30,13 @@ namespace train backend::ITensorRegistry *BackendContext::genTensors() { - // For now, there is no need to generate tensors for forwarding. + // For now, there is no need to generate tensors for forwarding and backwarding. // builtin train backend handles 3 operators: `Permute`, `IF`, `WHILE`. // `Permute`: Tensor generation is not required. // `IF`, `WHILE`: Not supported yet return tensor_registry().get(); } -backend::train::ITensorRegistry *BackendContext::genTrainingTensors() -{ - // For now, there is no need to generate tensors for backwarding. - return tensor_registry().get(); -} - backend::train::FunctionMap BackendContext::genKernels() { backend::train::FunctionMap ret; diff --git a/runtime/onert/core/src/backend/builtin/train/BackendContext.h b/runtime/onert/core/src/backend/builtin/train/BackendContext.h index 4782756c31c..c57a8020685 100644 --- a/runtime/onert/core/src/backend/builtin/train/BackendContext.h +++ b/runtime/onert/core/src/backend/builtin/train/BackendContext.h @@ -47,7 +47,6 @@ class BackendContext : public backend::train::TrainableBackendContext } backend::ITensorRegistry *genTensors() override; - backend::train::ITensorRegistry *genTrainingTensors() override; public: backend::train::FunctionMap genKernels() override; diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.cc b/runtime/onert/core/src/compiler/ExecutorFactory.cc index 3cbe5f670eb..0766a72174b 100644 --- a/runtime/onert/core/src/compiler/ExecutorFactory.cc +++ b/runtime/onert/core/src/compiler/ExecutorFactory.cc @@ -746,12 +746,6 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor( pair.second->genTensors(); } - for (auto &&pair : tbackend_contexts) - { - auto tctx = pair.second.get(); - tctx->genTrainingTensors(); - } - prepareMigrantTensors(*lowered_graph, tbackend_contexts); // Give some runtime objects to builtin KernelGenerator