diff --git a/runtime/onert/backend/train/KernelGenerator.cc b/runtime/onert/backend/train/KernelGenerator.cc
index 1e5d83b4565..f3df39735e0 100644
--- a/runtime/onert/backend/train/KernelGenerator.cc
+++ b/runtime/onert/backend/train/KernelGenerator.cc
@@ -80,21 +80,22 @@ generateBackPropAccumulator(const IPortableTensor *disposable, BackPropTensor *g
   return update_fn;
 }
 
-void appendBackPropAccumulator(const ir::train::ITrainableOperation &op,
-                               const ir::OperationIndex &op_index, TensorRegistry *tensor_reg,
-                               exec::train::TrainableFnSequence *seq)
+void appendBackPropAccumulators(const ir::train::ITrainableOperation &op,
+                                const ir::OperationIndex &op_index, TensorRegistry *tensor_reg,
+                                exec::train::TrainableFnSequence *seq)
 {
+  if (!op.isRequiredForBackward())
+    return;
+
   for (const auto &input_index : (op.getInputs() | ir::Remove::UNDEFINED))
   {
-    if (op.isRequiredForBackward())
+    const auto disposable =
+      tensor_reg->getDisposableBackPropTensor(DisposableTensorIndex{op_index, input_index});
+    if (disposable != nullptr)
     {
-      const auto disposable =
-        tensor_reg->getDisposableBackPropTensor(DisposableTensorIndex{op_index, input_index});
-      if (disposable != nullptr)
-      {
-        auto back_prop = tensor_reg->getBackPropTensor(input_index);
-        seq->append(generateBackPropAccumulator(disposable, back_prop));
-      }
+      auto back_prop = tensor_reg->getBackPropTensor(input_index);
+      assert(back_prop);
+      seq->append(generateBackPropAccumulator(disposable, back_prop));
     }
   }
 }
@@ -111,13 +112,16 @@ generateGradientApplier(const exec::train::optimizer::Optimizer *optimizer,
 
 std::unique_ptr<exec::train::TrainableFnSequence> KernelGenerator::generate(ir::OperationIndex idx)
 {
+  // NOTE This function is related to planning tensors. If you change this function, you should
+  //      also consider changing how tensors are planned.
+
   auto ret = std::make_unique<exec::train::TrainableFnSequence>();
 
   const auto &op = _tgraph.operation(idx);
 
-  // NOTE appendBackPropAccumulator() must be called before appending _return_fn to
+  // NOTE appendBackPropAccumulators() must be called before appending _return_fn to
   //      TrainableFnSequence as long as both are appended to the same TrainableFnSequence.
-  appendBackPropAccumulator(op, idx, _tensor_reg.get(), ret.get());
+  appendBackPropAccumulators(op, idx, _tensor_reg.get(), ret.get());
 
   op.accept(*this);
   assert(_return_fn);
@@ -604,19 +608,28 @@ void KernelGenerator::visit(const ir::train::operation::Softmax &node)
   _return_fn = std::move(fn);
 }
 
-IPortableTensor *KernelGenerator::getBackPropIn(const ir::Operation &node,
+IPortableTensor *KernelGenerator::getBackPropIn(const ir::IOperation &node,
                                                 const ir::OperandIndex &operand_index)
 {
   const auto &op_index = _node_to_idx[&node];
+  const auto backwarding_operand_index = ir::train::TrainingOperandIndex{operand_index, false};
 
-  auto temp_tensor =
+  const auto disposable_tensor =
     _tensor_reg->getDisposableBackPropTensor(DisposableTensorIndex{op_index, operand_index});
-  if (temp_tensor == nullptr)
+  if (disposable_tensor != nullptr)
   {
-    temp_tensor = _tensor_reg->getBackPropTensor(operand_index);
+    const auto &training_usedefs = _tgraph.trainingUseDefs().at(backwarding_operand_index);
+    UNUSED_RELEASE(training_usedefs);
+    assert(std::count_if(training_usedefs.getTrainingDefs().begin(),
+                         training_usedefs.getTrainingDefs().end(),
+                         [&](const ir::train::TrainingOperationIndex &op_index) {
+                           return _tgraph.operation(op_index.index()).isRequiredForBackward();
+                         }) > 1);
+
+    return disposable_tensor;
   }
-
-  return temp_tensor;
+  else
+    return _tensor_reg->getBackPropTensor(operand_index);
 }
 
 IPortableTensor *KernelGenerator::getBackPropOut(const ir::OperandIndex &output_index)
diff --git a/runtime/onert/backend/train/KernelGenerator.h b/runtime/onert/backend/train/KernelGenerator.h
index 329903b0b21..77f47e2588e 100644
--- a/runtime/onert/backend/train/KernelGenerator.h
+++ b/runtime/onert/backend/train/KernelGenerator.h
@@ -59,8 +59,7 @@ class KernelGenerator : public backend::train::KernelGeneratorBase
   void visit(const ir::train::operation::Softmax &node) override;
 
 private:
-  IPortableTensor *getBackPropIn(const ir::Operation &op_index,
-                                 const ir::OperandIndex &operand_index);
+  IPortableTensor *getBackPropIn(const ir::IOperation &node, const ir::OperandIndex &operand_index);
   IPortableTensor *getBackPropOut(const ir::OperandIndex &index);
 
 private:
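For context on what the accumulators above are for: when an operand is consumed by more than one backward-requiring operation, each consumer writes its gradient contribution into its own disposable tensor, and the function built by generateBackPropAccumulator() folds those contributions into the operand's shared back-prop tensor (summation, in the usual reverse-mode convention). Below is a minimal sketch of that accumulation step; GradientBuffer and accumulateBackProp are hypothetical stand-ins for the real IPortableTensor interface and the ITrainableFunction produced by generateBackPropAccumulator(), not the actual ONERT implementation.

#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a gradient buffer; the real code operates on
// IPortableTensor objects registered in the backend's TensorRegistry.
using GradientBuffer = std::vector<float>;

// Accumulate one consumer's contribution (the "disposable" tensor) into the
// operand's shared back-prop tensor: back_prop[i] += disposable[i].
void accumulateBackProp(const GradientBuffer &disposable, GradientBuffer &back_prop)
{
  assert(disposable.size() == back_prop.size());
  for (std::size_t i = 0; i < back_prop.size(); ++i)
    back_prop[i] += disposable[i];
}

The same reading explains the assert added to getBackPropIn(): a disposable tensor should exist only for an operand whose backward value is defined by more than one backward-requiring operation; with a single producer, that producer can write the shared back-prop tensor directly and no accumulation is needed.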