[onert] Allow to use multiple GradientAppliers (#11208)
This commit allows multiple GradientAppliers to be used in a TrainableFnSequence.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
ragmani authored Aug 4, 2023
1 parent: 2e8ff58 · commit: 6ea412b
Showing 2 changed files with 15 additions and 5 deletions.
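In short, KernelGenerator previously held a single pending update function (_update_fn) and could therefore register at most one GradientApplier per generated sequence; after this change it accumulates them in a vector (_update_funcs) and appends each one to the TrainableFnSequence. The following is a minimal standalone sketch of that pattern, using simplified stand-in types rather than the actual onert classes:

#include <iostream>
#include <memory>
#include <utility>
#include <vector>

// Simplified stand-in types; not the actual onert classes.
struct ITrainableFunction
{
  virtual ~ITrainableFunction() = default;
  virtual void run() = 0;
};

struct GradientApplier : ITrainableFunction
{
  explicit GradientApplier(const char *param) : _param{param} {}
  void run() override { std::cout << "apply gradient to " << _param << '\n'; }
  const char *_param;
};

struct TrainableFnSequence
{
  void append(std::unique_ptr<ITrainableFunction> fn) { _fns.push_back(std::move(fn)); }
  void run()
  {
    for (auto &&fn : _fns)
      fn->run();
  }
  std::vector<std::unique_ptr<ITrainableFunction>> _fns;
};

int main()
{
  // Previously at most one applier could be pending; now the generator can
  // collect one applier per trainable parameter of an operation.
  std::vector<std::unique_ptr<ITrainableFunction>> update_funcs;
  update_funcs.emplace_back(std::make_unique<GradientApplier>("kernel"));
  update_funcs.emplace_back(std::make_unique<GradientApplier>("bias"));

  // Mirrors KernelGenerator::generate(): move every pending applier into the
  // sequence, then clear the list for the next operation.
  TrainableFnSequence seq;
  for (auto &&update_fn : update_funcs)
    seq.append(std::move(update_fn));
  update_funcs.clear();

  seq.run(); // applies gradients to both parameters
  return 0;
}

Draining and clearing _update_funcs in generate() hands ownership to the sequence and leaves the generator empty for the next operation.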
runtime/onert/backend/train/KernelGenerator.cc (14 additions, 4 deletions)

@@ -74,6 +74,15 @@ ops::PoolType convertPoolType(ir::operation::Pool2D::PoolType type_ir)
     throw std::runtime_error("train KernelGenerator : Not supported operation yet");
   }
 }
+
+std::unique_ptr<ops::GradientApplier>
+generateGradientApplier(const std::shared_ptr<exec::train::optimizer::Optimizer> optimizer,
+                        const IPortableTensor *gradient, ITrainableTensor *trainable)
+{
+  auto update_fn = std::make_unique<ops::GradientApplier>();
+  update_fn->configure(optimizer, gradient, trainable);
+  return update_fn;
+}
 } // namespace
 
 std::unique_ptr<exec::train::TrainableFnSequence> KernelGenerator::generate(ir::OperationIndex idx)
@@ -85,8 +94,9 @@ std::unique_ptr<exec::train::TrainableFnSequence> KernelGenerator::generate(ir::OperationIndex idx)
   assert(_return_fn);
   ret->append(std::move(_return_fn));
 
-  if (_update_fn)
-    ret->append(std::move(_update_fn));
+  for (auto &&update_fn : _update_funcs)
+    ret->append(std::move(update_fn));
+  _update_funcs.clear();
 
   for (auto &&ind : (op.getInputs() | ir::Remove::UNDEFINED) + op.getOutputs())
   {
@@ -110,7 +120,7 @@ KernelGenerator::KernelGenerator(const ir::train::TrainableGraph &tgraph,
                                  std::shared_ptr<exec::train::optimizer::Optimizer> optimizer)
   : backend::train::KernelGeneratorBase{tgraph}, _current_layout{tgraph.layout()},
     _tensor_reg{tensor_reg},
-    _external_context(external_context), _optimizer{optimizer}, _update_fn{nullptr}
+    _external_context(external_context), _optimizer{optimizer}, _update_funcs{}
 {
   // DO NOTHING
 }
@@ -129,7 +139,7 @@ void KernelGenerator::visit(const ir::train::operation::Conv2D &node)
 
   update_fn->configure(_optimizer, grad_tensor, ker_tensor);
 
-  _update_fn = std::move(update_fn);
+  _update_funcs.emplace_back(generateGradientApplier(_optimizer, grad_tensor, ker_tensor));
 }
 
 void KernelGenerator::visit(const ir::train::operation::ElementwiseActivation &node)
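With the generateGradientApplier helper above, a visitor for an operation that trains more than one tensor can register one applier per parameter. The following is a hypothetical sketch only: a FullyConnected visitor with weight and bias gradient tensors is assumed for illustration and is not part of this commit.

// Hypothetical illustration (not in this commit): a visitor for an operation
// with two trainable tensors appends one applier per tensor.
void KernelGenerator::visit(const ir::train::operation::FullyConnected &node)
{
  // ... look up tensors and build the forward/backward kernel (elided) ...

  // One GradientApplier per trainable parameter; generate() later moves all
  // of _update_funcs into the TrainableFnSequence and clears the vector.
  _update_funcs.emplace_back(generateGradientApplier(_optimizer, weight_grad_tensor, weight_tensor));
  _update_funcs.emplace_back(generateGradientApplier(_optimizer, bias_grad_tensor, bias_tensor));
}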
runtime/onert/backend/train/KernelGenerator.h (1 addition, 1 deletion)

@@ -54,7 +54,7 @@ class KernelGenerator : public backend::train::KernelGeneratorBase
   std::shared_ptr<TensorRegistry> _tensor_reg;
   const std::shared_ptr<ExternalContext> _external_context;
   std::shared_ptr<exec::train::optimizer::Optimizer> _optimizer;
-  std::unique_ptr<exec::train::ITrainableFunction> _update_fn;
+  std::vector<std::unique_ptr<exec::train::ITrainableFunction>> _update_funcs;
 };
 
 } // namespace train
