diff --git a/runtime/onert/backend/train/KernelGenerator.cc b/runtime/onert/backend/train/KernelGenerator.cc index 66ed62cae6b..55d38ad8b1c 100644 --- a/runtime/onert/backend/train/KernelGenerator.cc +++ b/runtime/onert/backend/train/KernelGenerator.cc @@ -74,6 +74,15 @@ ops::PoolType convertPoolType(ir::operation::Pool2D::PoolType type_ir) throw std::runtime_error("train KernelGenerator : Not supported operation yet"); } } + +std::unique_ptr +generateGradientApplier(const std::shared_ptr optimizer, + const IPortableTensor *gradient, ITrainableTensor *trainable) +{ + auto update_fn = std::make_unique(); + update_fn->configure(optimizer, gradient, trainable); + return update_fn; +} } // namespace std::unique_ptr KernelGenerator::generate(ir::OperationIndex idx) @@ -85,8 +94,9 @@ std::unique_ptr KernelGenerator::generate(ir:: assert(_return_fn); ret->append(std::move(_return_fn)); - if (_update_fn) - ret->append(std::move(_update_fn)); + for (auto &&update_fn : _update_funcs) + ret->append(std::move(update_fn)); + _update_funcs.clear(); for (auto &&ind : (op.getInputs() | ir::Remove::UNDEFINED) + op.getOutputs()) { @@ -110,7 +120,7 @@ KernelGenerator::KernelGenerator(const ir::train::TrainableGraph &tgraph, std::shared_ptr optimizer) : backend::train::KernelGeneratorBase{tgraph}, _current_layout{tgraph.layout()}, _tensor_reg{tensor_reg}, - _external_context(external_context), _optimizer{optimizer}, _update_fn{nullptr} + _external_context(external_context), _optimizer{optimizer}, _update_funcs{} { // DO NOTHING } @@ -129,7 +139,7 @@ void KernelGenerator::visit(const ir::train::operation::Conv2D &node) update_fn->configure(_optimizer, grad_tensor, ker_tensor); - _update_fn = std::move(update_fn); + _update_funcs.emplace_back(generateGradientApplier(_optimizer, grad_tensor, ker_tensor)); } void KernelGenerator::visit(const ir::train::operation::ElementwiseActivation &node) diff --git a/runtime/onert/backend/train/KernelGenerator.h b/runtime/onert/backend/train/KernelGenerator.h index 10ba4826f02..7827389be34 100644 --- a/runtime/onert/backend/train/KernelGenerator.h +++ b/runtime/onert/backend/train/KernelGenerator.h @@ -54,7 +54,7 @@ class KernelGenerator : public backend::train::KernelGeneratorBase std::shared_ptr _tensor_reg; const std::shared_ptr _external_context; std::shared_ptr _optimizer; - std::unique_ptr _update_fn; + std::vector> _update_funcs; }; } // namespace train