[onert] Generate BackPropAccumulator only if necessary (#13570)
This commit changes BackPropAccumulators to be generated only if the corresponding disposable tensors exist.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
ragmani authored Aug 5, 2024
1 parent 3a46768 commit 62eaa3d
Showing 2 changed files with 33 additions and 21 deletions.
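
The pattern at the heart of this commit: an accumulator function is appended to an operation's function sequence only when a disposable back-prop tensor is actually registered for that (operation, input) pair; inputs without one need no accumulation step. A runnable sketch of that pattern under simplified stand-in types (Tensor, Fn, and the maps below are illustrative, not the onert API):

#include <functional>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

struct Tensor { float grad = 0.f; }; // hypothetical stand-in for onert's tensors
using Fn = std::function<void()>;

int main()
{
  // Disposable gradient pieces exist only for some (op, input) pairs.
  std::map<std::pair<int, int>, Tensor> disposable{{{0, 1}, Tensor{0.5f}}};
  std::map<int, Tensor> back_prop{{0, {}}, {1, {}}};
  std::vector<Fn> seq; // stands in for exec::train::TrainableFnSequence

  const int op_index = 0;
  const bool required_for_backward = true;
  if (required_for_backward) // checked once, outside the loop, as in the commit
  {
    for (int input_index : {0, 1})
    {
      auto it = disposable.find({op_index, input_index});
      if (it == disposable.end())
        continue; // no disposable tensor registered -> no accumulator generated
      Tensor *src = &it->second;
      Tensor *dst = &back_prop[input_index];
      seq.push_back([src, dst] { dst->grad += src->grad; }); // the "accumulate" step
    }
  }

  for (auto &fn : seq)
    fn();
  std::cout << "accumulators generated: " << seq.size() << '\n'; // prints 1, not 2
}

Here only input 1 has a disposable gradient, so only one accumulator ends up in the sequence — the analogue of skipping BackPropAccumulators for inputs with no disposable tensor.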
51 changes: 32 additions & 19 deletions runtime/onert/backend/train/KernelGenerator.cc
@@ -80,21 +80,22 @@ generateBackPropAccumulator(const IPortableTensor *disposable, BackPropTensor *g
   return update_fn;
 }
 
-void appendBackPropAccumulator(const ir::train::ITrainableOperation &op,
-                               const ir::OperationIndex &op_index, TensorRegistry *tensor_reg,
-                               exec::train::TrainableFnSequence *seq)
+void appendBackPropAccumulators(const ir::train::ITrainableOperation &op,
+                                const ir::OperationIndex &op_index, TensorRegistry *tensor_reg,
+                                exec::train::TrainableFnSequence *seq)
 {
+  if (!op.isRequiredForBackward())
+    return;
+
   for (const auto &input_index : (op.getInputs() | ir::Remove::UNDEFINED))
   {
-    if (op.isRequiredForBackward())
+    const auto disposable =
+      tensor_reg->getDisposableBackPropTensor(DisposableTensorIndex{op_index, input_index});
+    if (disposable != nullptr)
     {
-      const auto disposable =
-        tensor_reg->getDisposableBackPropTensor(DisposableTensorIndex{op_index, input_index});
-      if (disposable != nullptr)
-      {
-        auto back_prop = tensor_reg->getBackPropTensor(input_index);
-        seq->append(generateBackPropAccumulator(disposable, back_prop));
-      }
+      auto back_prop = tensor_reg->getBackPropTensor(input_index);
+      assert(back_prop);
+      seq->append(generateBackPropAccumulator(disposable, back_prop));
     }
   }
 }
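
Two details worth noting in this hunk: the isRequiredForBackward() check is hoisted out of the input loop into an early return, so operations not needed for backward skip the loop entirely, and the new assert(back_prop) makes the implied invariant explicit — whenever a disposable tensor was registered for an input, a back-prop tensor for that input must exist as well.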
@@ -111,13 +112,16 @@ generateGradientApplier(const exec::train::optimizer::Optimizer *optimizer,
 
 std::unique_ptr<exec::train::TrainableFnSequence> KernelGenerator::generate(ir::OperationIndex idx)
 {
+  // NOTE This function is related to planning tensors. If you change this function, you should
+  // also consider to change planning tensors.
+
   auto ret = std::make_unique<exec::train::TrainableFnSequence>();
 
   const auto &op = _tgraph.operation(idx);
 
-  // NOTE appendBackPropAccumulator() must be called before appending _return_fn to
+  // NOTE appendBackPropAccumulators() must be called before appending _return_fn to
   // TrainableFnSequence as long as both are appended to the same TrainableFnSequence.
-  appendBackPropAccumulator(op, idx, _tensor_reg.get(), ret.get());
+  appendBackPropAccumulators(op, idx, _tensor_reg.get(), ret.get());
 
   op.accept(*this);
   assert(_return_fn);
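
The renamed NOTE restates an ordering contract: the accumulators and the operation's backward function (_return_fn) land in the same TrainableFnSequence, so their relative append order is what the runtime relies on when it walks the sequence; appendBackPropAccumulators() must therefore keep running before _return_fn is appended.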
@@ -604,19 +608,28 @@ void KernelGenerator::visit(const ir::train::operation::Softmax &node)
   _return_fn = std::move(fn);
 }
 
-IPortableTensor *KernelGenerator::getBackPropIn(const ir::Operation &node,
+IPortableTensor *KernelGenerator::getBackPropIn(const ir::IOperation &node,
                                                 const ir::OperandIndex &operand_index)
 {
   const auto &op_index = _node_to_idx[&node];
+  const auto backwarding_operand_index = ir::train::TrainingOperandIndex{operand_index, false};
 
-  auto temp_tensor =
+  const auto disposable_tensor =
     _tensor_reg->getDisposableBackPropTensor(DisposableTensorIndex{op_index, operand_index});
-  if (temp_tensor == nullptr)
+  if (disposable_tensor != nullptr)
   {
-    temp_tensor = _tensor_reg->getBackPropTensor(operand_index);
+    const auto &training_usedefs = _tgraph.trainingUseDefs().at(backwarding_operand_index);
+    UNUSED_RELEASE(training_usedefs);
+    assert(std::count_if(training_usedefs.getTrainingDefs().begin(),
+                         training_usedefs.getTrainingDefs().end(),
+                         [&](const ir::train::TrainingOperationIndex &op_index) {
+                           return _tgraph.operation(op_index.index()).isRequiredForBackward();
+                         }) > 1);
+
+    return disposable_tensor;
   }
-
-  return temp_tensor;
+  else
+    return _tensor_reg->getBackPropTensor(operand_index);
 }
 
 IPortableTensor *KernelGenerator::getBackPropOut(const ir::OperandIndex &output_index)
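
The new branch in getBackPropIn() encodes an invariant rather than a fallback search: if a disposable back-prop tensor exists for this (operation, operand) pair, the assert checks that more than one backward-required operation defines that operand's gradient — i.e. that accumulation is genuinely needed — and otherwise the plain back-prop tensor is returned directly. A self-contained sketch of the counting idiom used by the assert (Op and defs are simplified stand-ins, not onert's training use-def structures):

#include <algorithm>
#include <cassert>
#include <vector>

struct Op { bool required_for_backward; }; // hypothetical stand-in

int main()
{
  // Backward-graph ops that each write a gradient contribution for one operand.
  std::vector<Op> defs{{true}, {false}, {true}};

  // Same counting idiom as the new assert: the disposable path is only
  // consistent when more than one backward-required op contributes a gradient,
  // i.e. when accumulation is actually needed.
  const auto contributors =
    std::count_if(defs.begin(), defs.end(),
                  [](const Op &op) { return op.required_for_backward; });
  (void)contributors; // plays the role of UNUSED_RELEASE when asserts compile out
  assert(contributors > 1);
  return 0;
}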
3 changes: 1 addition & 2 deletions runtime/onert/backend/train/KernelGenerator.h
@@ -59,8 +59,7 @@ class KernelGenerator : public backend::train::KernelGeneratorBase
   void visit(const ir::train::operation::Softmax &node) override;
 
 private:
-  IPortableTensor *getBackPropIn(const ir::Operation &op_index,
-                                 const ir::OperandIndex &operand_index);
+  IPortableTensor *getBackPropIn(const ir::IOperation &node, const ir::OperandIndex &operand_index);
   IPortableTensor *getBackPropOut(const ir::OperandIndex &index);
 
 private:
