diff --git a/runtime/onert/backend/train/BackendContext.cc b/runtime/onert/backend/train/BackendContext.cc
index 3e72b72d2ef..f5bf9999671 100644
--- a/runtime/onert/backend/train/BackendContext.cc
+++ b/runtime/onert/backend/train/BackendContext.cc
@@ -17,6 +17,7 @@
 #include "BackendContext.h"
 
 #include "TensorBuilder.h"
+#include "TensorPlanner.h"
 #include "KernelGenerator.h"
 
 #include "ops/BackPropInitializer.h"
@@ -41,25 +42,6 @@ ir::OperandInfo createBackwardTensorInfo(const ir::Operand &operand)
                          operand.isConstant()};
 }
 
-// NOTE Even if there are duplicate indices, the duplicate back-propagated tensors may need
-// to be updated respectively. So we use a sequence instead of a set.
-ir::OperandIndexSequence getBackPropSeq(const ir::train::TrainableGraph &tgraph,
-                                        const ir::OperationIndex &op_index)
-{
-  ir::OperandIndexSequence ret;
-
-  const auto &op = tgraph.operations().at(op_index);
-  for (const auto &input : (op.getInputs() | ir::Remove::UNDEFINED))
-  {
-    const auto &operand = tgraph.operands().at(input);
-    // TODO Remove other inputs that are not back-propagated
-    if (!operand.isConstant() && !tgraph.getInputs().contains(input))
-      ret.append(input);
-  }
-
-  return ret;
-}
-
 void AddBackPropInitializers(const ir::train::TrainableGraph &tgraph, TensorRegistry &tensor_reg,
                              FunctionMap &fn_map)
 {
@@ -97,99 +79,134 @@ void AddBackPropInitializers(const ir::train::TrainableGraph &tgraph, TensorRegi
     }
   }
 }
-} // namespace
 
-backend::ITensorRegistry *BackendContext::genTensors()
+util::Set<ir::train::TrainingOperandIndex>
+getBackwardTensorList(const ir::train::TrainableGraph &tgraph,
+                      const util::Set<ir::OperandIndex> &external_operands)
 {
-  return basic::train::genTensors(*this, _tensor_builder);
-}
+  util::Set<ir::train::TrainingOperandIndex> ret;
 
-backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
-{
-  const ir::train::TrainableGraph &tgraph = *trainable_graph();
-  auto tensor_builder = _tensor_builder;
-  tgraph.operations().iterate([&](const ir::OperationIndex &, const ir::IOperation &op) {
-    const auto trainable_op = dynamic_cast<const ir::train::ITrainableOperation *>(&op);
-    assert(trainable_op);
-    if (!trainable_op->isRequiredForBackward())
-    {
-      return;
-    }
-    for (const auto &ind : op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+  // TODO Reuse registered tensors when they are planned for memory optimization.
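+  // Collect backward (back-prop/gradient) operand indices that actually take part in
+  // training: inputs of ops in essentialBackwardOrder(), skipping external operands
+  // and operands that have no training use-defs at all.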
+  auto border = tgraph.essentialBackwardOrder();
+  for (const auto op_index : border)
+  {
+    const auto &trainable_op = tgraph.operation(op_index);
+    assert(trainable_op.isRequiredForBackward());
+    // This assumes that back-propagated tensors of loss outputs are not used
+    for (const auto &ind :
+         trainable_op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
     {
-      if (tensor_builder->isRegisteredBackward(ind))
+      if (external_operands.contains(ind))
         continue;
-      if (external_operands().contains(ind))
+
+      const auto &operand_index = ir::train::TrainingOperandIndex{ind, false};
+
+      const auto &training_usedefs = tgraph.trainingUseDefs();
+      const auto &usedefs = training_usedefs.at(ir::train::TrainingOperandIndex{ind, false});
+      const bool not_used = usedefs.getTrainingDefs().empty() && usedefs.getTrainingUses().empty();
+      if (not_used)
         continue;
-      const auto &operand = tgraph.operands().at(ind);
-      tensor_builder->registerBackwardTensorInfo(ind, createBackwardTensorInfo(operand));
+
+      ret.add(operand_index);
     }
-  });
+  }
 
-  for (const auto &op_index : tgraph.essentialBackwardOrder())
+  return ret;
+}
+
+util::Set<DisposableTensorIndex>
+getDisposableBackPropTensorList(const ir::train::TrainableGraph &tgraph,
+                                const util::Set<ir::OperandIndex> &external_operands)
+{
+  util::Set<DisposableTensorIndex> ret;
+
+  const auto candidates = getBackwardTensorList(tgraph, external_operands);
+  for (const auto &backwarding_operand_index : candidates)
   {
-    const auto back_prop_seq = getBackPropSeq(tgraph, op_index);
-    for (const auto &back_prop_index : back_prop_seq)
-    {
-      DisposableTensorIndex disposable_index{op_index, back_prop_index};
-      const auto &operand = tgraph.operands().at(back_prop_index);
-      tensor_builder->registerDisposableBackwardTensorInfo(disposable_index,
-                                                           createBackwardTensorInfo(operand));
-    }
+    const auto &operand = tgraph.operands().at(backwarding_operand_index.index());
+    const auto &training_usedefs = tgraph.trainingUseDefs();
+    const auto &usedefs = training_usedefs.at(backwarding_operand_index);
+    const bool is_multiple_defs = usedefs.getTrainingDefs().size() > 1;
+    if (!operand.isConstant() && is_multiple_defs)
+      for (const auto &def : usedefs.getTrainingDefs())
+        ret.add(DisposableTensorIndex{def.index(), backwarding_operand_index.index()});
   }
 
-  planBackwardTensors();
+  return ret;
+}
+} // namespace
 
-  tensor_builder->allocateBackward();
+backend::ITensorRegistry *BackendContext::genTensors()
+{
+  planForwardTensors();
+
+  _tensor_builder->allocate();
 
   return _tensor_registry.get();
 }
 
-void BackendContext::planForwardTensors()
+backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
 {
-  // TODO Plan forwarding tensors
+  planBackwardTensors();
+
+  _tensor_builder->allocateBackward();
+
+  return _tensor_registry.get();
 }
 
-void BackendContext::planBackwardTensors()
+void BackendContext::planForwardTensors()
 {
-  const ir::train::TrainableGraph &tgraph = *trainable_graph();
-  auto tensor_builder = _tensor_builder;
+  const auto &tgraph = *trainable_graph();
+
+  tgraph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &obj) {
+    if (external_operands().contains(index))
+      return;
+    if (!index.valid())
+      return;
 
-  // TODO Plan tensor builds to reduce peak memory usage
-  tgraph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &) {
-    if (tensor_builder->isRegisteredBackward(ind))
-      tensor_builder->notifyBackwardFirstUse(ind);
+    _tensor_builder->registerTensorInfo(index, obj.info());
   });
 
-  planDisposableBackPropTensors();
+  const auto ctx_data = data();
+  TensorPlanner tensor_planner{*ctx_data->tgraph.get(), ctx_data->external_operands};
+  tensor_planner.planTrainableTensors(_tensor_builder.get());
+  tensor_planner.planNonConstTensors(_tensor_builder.get());
 }
 
-void BackendContext::planDisposableBackPropTensors()
+void BackendContext::planBackwardTensors()
 {
   const ir::train::TrainableGraph &tgraph = *trainable_graph();
+  auto tensor_builder = _tensor_builder;
 
-  std::vector<DisposableTensorIndex> prev_seq;
-  for (const auto &op_index : tgraph.essentialBackwardOrder())
+  const auto operand_indices = getBackwardTensorList(tgraph, external_operands());
+  for (const auto &operand_index : operand_indices)
   {
-    for (const auto &index : prev_seq)
-    {
-      tensor_builder->notifyDisposableBackPropLastUse(index);
-    }
+    if (external_operands().contains(operand_index.index()))
+      continue;
 
-    std::vector<DisposableTensorIndex> cur_seq;
-    const auto back_prop_indices = getBackPropSeq(tgraph, op_index);
-    for (const auto &back_prop_index : back_prop_indices)
-    {
-      DisposableTensorIndex cur_index{op_index, back_prop_index};
-      tensor_builder->notifyDisposableBackPropFirstUse(cur_index);
+    assert(operand_index.valid());
-      cur_seq.emplace_back(cur_index);
-    }
+    assert(!operand_index.is_forward());
+    const auto &operand = tgraph.operands().at(operand_index.index());
+    tensor_builder->registerBackwardTensorInfo(operand_index.index(),
+                                               createBackwardTensorInfo(operand));
+  }
 
-    prev_seq = cur_seq;
+  const auto disposable_indices = getDisposableBackPropTensorList(tgraph, external_operands());
+  for (const auto &disposable_index : disposable_indices)
+  {
+    const auto &operand = tgraph.operands().at(disposable_index.operand_index());
+    tensor_builder->registerDisposableBackwardTensorInfo(disposable_index,
+                                                         createBackwardTensorInfo(operand));
   }
+
+  // Plan tensors only in backwarding to reduce peak memory usage
+  const auto ctx_data = data();
+  TensorPlanner tensor_planner{*ctx_data->tgraph.get(), ctx_data->external_operands};
+  tensor_planner.planGradientTensors(tensor_builder.get());
+  tensor_planner.planBackPropTensors(tensor_builder.get());
+  tensor_planner.planDisposableBackPropTensors(tensor_builder.get());
 }
 
 FunctionMap BackendContext::genKernels()
diff --git a/runtime/onert/backend/train/BackendContext.h b/runtime/onert/backend/train/BackendContext.h
index 6c36a1e8924..69d17d352c4 100644
--- a/runtime/onert/backend/train/BackendContext.h
+++ b/runtime/onert/backend/train/BackendContext.h
@@ -85,9 +85,6 @@ class BackendContext : public onert::backend::train::TrainableBackendContext
 private:
   FunctionMap generateFunctionMap();
 
-private:
-  void planDisposableBackPropTensors();
-
 public:
   // TODO Make it private
   std::shared_ptr<KernelGenerator> kernel_gen;
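
Editor's note: the patch separates tensor registration (registerTensorInfo, registerBackwardTensorInfo, registerDisposableBackwardTensorInfo) from lifetime planning, which now lives in TensorPlanner, and allocates once per direction (allocate() in genTensors(), allocateBackward() in genTrainingTensors()). The sketch below is a self-contained toy model of the underlying first-use/last-use interval planning idea that the removed notifyDisposableBackPropFirstUse/LastUse calls expressed; it is not onert code, and every name and size in it is invented for illustration.

// Toy model of interval-based tensor planning (illustration only; not onert's API).
#include <cstddef>
#include <cstdio>
#include <vector>

struct Planner
{
  std::size_t live = 0; // bytes currently claimed
  std::size_t peak = 0; // maximum of `live` seen so far
  void claim(std::size_t bytes)
  {
    live += bytes;
    if (live > peak)
      peak = live;
  }
  void release(std::size_t bytes) { live -= bytes; }
};

int main()
{
  // Each tensor is live over [first_use, last_use] in (backward) operation order.
  struct Tensor
  {
    std::size_t size;
    int first_use;
    int last_use;
  };
  const std::vector<Tensor> tensors = {
    {4096, 0, 1}, // hypothetical back-prop tensor for op 0's input
    {4096, 1, 2}, // hypothetical back-prop tensor for op 1's input
    {1024, 2, 3}, // hypothetical gradient tensor, kept until the weight update
  };

  Planner planner;
  for (int op = 0; op <= 3; ++op)
  {
    for (const auto &t : tensors)
      if (t.first_use == op)
        planner.claim(t.size); // analogous to a first-use notification
    for (const auto &t : tensors)
      if (t.last_use == op)
        planner.release(t.size); // analogous to a last-use notification
  }
  std::printf("peak bytes with interval planning: %zu\n", planner.peak);
  return 0;
}

In this toy schedule the two back-prop buffers are never live at the same time as each other and the gradient buffer only overlaps the second one, so the planned peak is 8192 bytes rather than the 9216 bytes needed if all three tensors were allocated for the whole backward pass, which is the "reduce peak memory usage" effect the patch's planning-in-backwarding comment refers to.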