[onert] Revise generating tensors for training #13571

Merged 1 commit on Aug 1, 2024
171 changes: 94 additions & 77 deletions runtime/onert/backend/train/BackendContext.cc
@@ -17,6 +17,7 @@
#include "BackendContext.h"

#include "TensorBuilder.h"
#include "TensorPlanner.h"
#include "KernelGenerator.h"
#include "ops/BackPropInitializer.h"

@@ -41,25 +42,6 @@ ir::OperandInfo createBackwardTensorInfo(const ir::Operand &operand)
operand.isConstant()};
}

// NOTE Even if there are duplicate indices, the duplicate back-propagated tensors may need
// to be updated respectively. So we use a sequence instead of a set.
ir::OperandIndexSequence getBackPropSeq(const ir::train::TrainableGraph &tgraph,
const ir::OperationIndex &op_index)
{
ir::OperandIndexSequence ret;

const auto &op = tgraph.operations().at(op_index);
for (const auto &input : (op.getInputs() | ir::Remove::UNDEFINED))
{
const auto &operand = tgraph.operands().at(input);
// TODO Remove other inputs that are not back-propagated
if (!operand.isConstant() && !tgraph.getInputs().contains(input))
ret.append(input);
}

return ret;
}

void AddBackPropInitializers(const ir::train::TrainableGraph &tgraph, TensorRegistry &tensor_reg,
FunctionMap &fn_map)
{
Expand Down Expand Up @@ -97,99 +79,134 @@ void AddBackPropInitializers(const ir::train::TrainableGraph &tgraph, TensorRegi
}
}
}
} // namespace

backend::ITensorRegistry *BackendContext::genTensors()
util::Set<ir::train::TrainingOperandIndex>
getBackwardTensorList(const ir::train::TrainableGraph &tgraph,
const util::Set<ir::OperandIndex> &external_operands)
{
return basic::train::genTensors(*this, _tensor_builder);
}
util::Set<ir::train::TrainingOperandIndex> ret;

backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
{
const ir::train::TrainableGraph &tgraph = *trainable_graph();
auto tensor_builder = _tensor_builder;
tgraph.operations().iterate([&](const ir::OperationIndex &, const ir::IOperation &op) {
const auto trainable_op = dynamic_cast<const ir::train::TrainableOperation *>(&op);
assert(trainable_op);
if (!trainable_op->isRequiredForBackward())
{
return;
}
for (const auto &ind : op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
// TODO Reuse registered tensors when they are planned for memory optimization.
auto border = tgraph.essentialBackwardOrder();
for (const auto op_index : border)
{
const auto &trainable_op = tgraph.operation(op_index);
assert(trainable_op.isRequiredForBackward());
// This assumes that back-propagated tensors of loss outputs are not used
for (const auto &ind :
trainable_op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
{
if (tensor_builder->isRegisteredBackward(ind))
if (external_operands.contains(ind))
continue;
if (external_operands().contains(ind))

const auto &operand_index = ir::train::TrainingOperandIndex{ind, false};

const auto &training_usedefs = tgraph.trainingUseDefs();
const auto &usedefs = training_usedefs.at(ir::train::TrainingOperandIndex{ind, false});
const bool not_used = usedefs.getTrainingDefs().empty() && usedefs.getTrainingUses().empty();
if (not_used)
continue;

const auto &operand = tgraph.operands().at(ind);
tensor_builder->registerBackwardTensorInfo(ind, createBackwardTensorInfo(operand));
ret.add(operand_index);
}
});
}

for (const auto &op_index : tgraph.essentialBackwardOrder())
return ret;
}

util::Set<DisposableTensorIndex>
getDisposableBackPropTensorList(const ir::train::TrainableGraph &tgraph,
const util::Set<ir::OperandIndex> &external_operands)
{
util::Set<DisposableTensorIndex> ret;

const auto candidates = getBackwardTensorList(tgraph, external_operands);
for (const auto &backwarding_operand_index : candidates)
{
const auto back_prop_seq = getBackPropSeq(tgraph, op_index);
for (const auto &back_prop_index : back_prop_seq)
{
DisposableTensorIndex disposable_index{op_index, back_prop_index};
const auto &operand = tgraph.operands().at(back_prop_index);
tensor_builder->registerDisposableBackwardTensorInfo(disposable_index,
createBackwardTensorInfo(operand));
}
const auto &operand = tgraph.operands().at(backwarding_operand_index.index());
const auto &training_usedefs = tgraph.trainingUseDefs();
const auto &usedefs = training_usedefs.at(backwarding_operand_index);
const bool is_multiple_defs = usedefs.getTrainingDefs().size() > 1;
if (!operand.isConstant() && is_multiple_defs)
for (const auto &def : usedefs.getTrainingDefs())
ret.add(DisposableTensorIndex{def.index(), backwarding_operand_index.index()});
}

planBackwardTensors();
return ret;
}
} // namespace

tensor_builder->allocateBackward();
backend::ITensorRegistry *BackendContext::genTensors()
{
planForwardTensors();

_tensor_builder->allocate();

return _tensor_registry.get();
}

void BackendContext::planForwardTensors()
backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
{
// TODO Plan forwarding tensors
planBackwardTensors();

_tensor_builder->allocateBackward();

return _tensor_registry.get();
}

void BackendContext::planBackwardTensors()
void BackendContext::planForwardTensors()
{
const ir::train::TrainableGraph &tgraph = *trainable_graph();
auto tensor_builder = _tensor_builder;
const auto &tgraph = *trainable_graph();

tgraph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &obj) {
if (external_operands().contains(index))
return;
if (!index.valid())
return;

// TODO Plan tensor builds to reduce peak memory usage
tgraph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &) {
if (tensor_builder->isRegisteredBackward(ind))
tensor_builder->notifyBackwardFirstUse(ind);
_tensor_builder->registerTensorInfo(index, obj.info());
});

planDisposableBackPropTensors();
const auto ctx_data = data();
TensorPlanner tensor_planner{*ctx_data->tgraph.get(), ctx_data->external_operands};
tensor_planner.planTrainableTensors(_tensor_builder.get());
tensor_planner.planNonConstTensors(_tensor_builder.get());
}

void BackendContext::planDisposableBackPropTensors()
void BackendContext::planBackwardTensors()
{
const ir::train::TrainableGraph &tgraph = *trainable_graph();

auto tensor_builder = _tensor_builder;

std::vector<DisposableTensorIndex> prev_seq;
for (const auto &op_index : tgraph.essentialBackwardOrder())
const auto operand_indices = getBackwardTensorList(tgraph, external_operands());
for (const auto &operand_index : operand_indices)
{
for (const auto &index : prev_seq)
{
tensor_builder->notifyDisposableBackPropLastUse(index);
}
if (external_operands().contains(operand_index.index()))
continue;

std::vector<DisposableTensorIndex> cur_seq;
const auto back_prop_indices = getBackPropSeq(tgraph, op_index);
for (const auto &back_prop_index : back_prop_indices)
{
DisposableTensorIndex cur_index{op_index, back_prop_index};
tensor_builder->notifyDisposableBackPropFirstUse(cur_index);
assert(operand_index.valid());

cur_seq.emplace_back(cur_index);
}
assert(!operand_index.is_forward());
const auto &operand = tgraph.operands().at(operand_index.index());
tensor_builder->registerBackwardTensorInfo(operand_index.index(),
createBackwardTensorInfo(operand));
}

prev_seq = cur_seq;
const auto disposable_indices = getDisposableBackPropTensorList(tgraph, external_operands());
for (const auto &disposable_index : disposable_indices)
{
const auto &operand = tgraph.operands().at(disposable_index.operand_index());
tensor_builder->registerDisposableBackwardTensorInfo(disposable_index,
createBackwardTensorInfo(operand));
}

// Plan tensors only in backwarding to reduce peak memory usage
const auto ctx_data = data();
TensorPlanner tensor_planner{*ctx_data->tgraph.get(), ctx_data->external_operands};
tensor_planner.planGradientTensors(tensor_builder.get());
tensor_planner.planBackPropTensors(tensor_builder.get());
tensor_planner.planDisposableBackPropTensors(tensor_builder.get());
}

FunctionMap BackendContext::genKernels()
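Note on the selection logic above: getDisposableBackPropTensorList() now creates per-definition disposable tensors only for non-constant operands whose backward use-defs contain more than one defining operation, replacing the old per-operation getBackPropSeq() walk. Below is a minimal, self-contained sketch of that rule; OperandIndex, UseDefs, and DisposableTensorIndex here are simplified hypothetical stand-ins, not onert's actual types.

```cpp
// Hypothetical, simplified stand-ins for onert's index and use-def types.
#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <utility>
#include <vector>

using OperandIndex = uint32_t;
using OperationIndex = uint32_t;
// A disposable back-prop tensor is keyed by (defining operation, operand).
using DisposableTensorIndex = std::pair<OperationIndex, OperandIndex>;

struct UseDefs
{
  std::set<OperationIndex> defs; // operations that write this back-prop tensor
  std::set<OperationIndex> uses; // operations that read it
};

// Mirrors the selection rule sketched in getDisposableBackPropTensorList(): an operand
// that is not constant and is defined by more than one operation gets one disposable
// tensor per defining operation, so each partial gradient can be written and released.
std::vector<DisposableTensorIndex>
selectDisposableBackProps(const std::map<OperandIndex, UseDefs> &backward_usedefs,
                          const std::set<OperandIndex> &constants)
{
  std::vector<DisposableTensorIndex> ret;
  for (const auto &[operand, usedefs] : backward_usedefs)
  {
    const bool is_constant = constants.count(operand) != 0;
    const bool has_multiple_defs = usedefs.defs.size() > 1;
    if (!is_constant && has_multiple_defs)
      for (const auto def : usedefs.defs)
        ret.emplace_back(def, operand);
  }
  return ret;
}

int main()
{
  // Operand 7 is written by ops 2 and 5 (e.g. an input consumed by two layers),
  // operand 8 by op 3 only; only operand 7 yields disposable back-prop tensors.
  std::map<OperandIndex, UseDefs> usedefs{{7, {{2, 5}, {6}}}, {8, {{3}, {6}}}};
  for (const auto &[op, operand] : selectDisposableBackProps(usedefs, /*constants=*/{}))
    std::cout << "disposable back-prop: op " << op << ", operand " << operand << '\n';
}
```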
3 changes: 0 additions & 3 deletions runtime/onert/backend/train/BackendContext.h
@@ -85,9 +85,6 @@ class BackendContext : public onert::backend::train::TrainableBackendContext
private:
FunctionMap generateFunctionMap();

private:
void planDisposableBackPropTensors();

public:
// TODO Make it private
std::shared_ptr<KernelGenerator> kernel_gen;
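The TensorPlanner calls added in BackendContext.cc (planTrainableTensors, planNonConstTensors, planGradientTensors, planBackPropTensors, planDisposableBackPropTensors) hand lifetime information to the planner before anything is allocated, which is what lets buffers with disjoint lifetimes share memory and keeps peak usage down during backwarding. Below is a minimal first-fit lifetime-planning sketch of that general idea; the Tensor and Placement types and the plan() function are hypothetical illustrations, not the TensorPlanner API.

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical example, not the onert TensorPlanner API: a tensor's lifetime is the
// [first_use, last_use] range of operation indices in which it must stay resident.
struct Tensor
{
  std::string name;
  size_t size;
  int first_use;
  int last_use;
};

struct Placement
{
  std::string name;
  size_t offset;
};

// Walk tensors in the given order and place each one at an offset that does not
// collide with any already-placed tensor whose lifetime overlaps, bumping past
// conflicts. Planning placement before allocation lets buffers with disjoint
// lifetimes reuse the same memory, reducing peak usage.
std::vector<Placement> plan(const std::vector<Tensor> &tensors)
{
  std::vector<Placement> placements;
  std::vector<Tensor> placed;
  for (const auto &t : tensors)
  {
    size_t offset = 0;
    bool moved = true;
    while (moved)
    {
      moved = false;
      for (size_t i = 0; i < placed.size(); ++i)
      {
        const auto &p = placed[i];
        const bool lifetimes_overlap = t.first_use <= p.last_use && p.first_use <= t.last_use;
        const bool ranges_overlap =
          offset < placements[i].offset + p.size && placements[i].offset < offset + t.size;
        if (lifetimes_overlap && ranges_overlap)
        {
          offset = placements[i].offset + p.size; // bump past the conflicting buffer
          moved = true;                           // and re-check everything from the start
        }
      }
    }
    placements.push_back({t.name, offset});
    placed.push_back(t);
  }
  return placements;
}

int main()
{
  // "act" is dead before "grad" is born, so both can share offset 0.
  std::vector<Tensor> tensors{{"act", 64, 0, 2}, {"grad", 64, 3, 5}, {"weight_grad", 32, 1, 5}};
  for (const auto &p : plan(tensors))
    std::cout << p.name << " -> offset " << p.offset << '\n';
}
```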