Skip to content

Commit

Permalink
[onert/backend] Add LayerScopeTensor Planning
Browse files Browse the repository at this point in the history
This PR adds TensorPlanner::planLayerScopeTensors() and its caller part in BackendContext.

ONE-DCO-1.0-Signed-off-by: seunghui youn <[email protected]>

--------------------------------------

draft : #13486
  • Loading branch information
zetwhite committed Oct 14, 2024
1 parent b70d6eb commit 466f6c9
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 0 deletions.
15 changes: 15 additions & 0 deletions runtime/onert/backend/train/BackendContext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ FunctionMap BackendContext::gen()
// fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
// }

// NOTE: Since LayerScopeTensors is defined in each kernel(layer),
// It should be planned and allocated after the kernels generated.
planLayerScopeTensors(fn_map);
_tensor_builder->allocateLayerScope();

return fn_map;
}

Expand Down Expand Up @@ -255,6 +260,16 @@ FunctionMap BackendContext::generateFunctionMap()
return ret;
}

void BackendContext::planLayerScopeTensors(const FunctionMap &fn_map)
{
// TODO: Register LayerScopeTensors

const auto ctx_data = data();
TensorPlanner tensor_planner{*ctx_data->tgraph.get(), ctx_data->external_operands};
tensor_planner.planLayerScopeTensors(_tensor_builder.get());
return;
}

} // namespace train
} // namespace backend
} // namespace onert
1 change: 1 addition & 0 deletions runtime/onert/backend/train/BackendContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class BackendContext : public onert::backend::train::TrainableBackendContext
private:
void planForwardTensors();
void planBackwardTensors();
void planLayerScopeTensors(const FunctionMap &fn_map);

public:
std::shared_ptr<ExternalContext> external_context() { return _external_context; }
Expand Down
43 changes: 43 additions & 0 deletions runtime/onert/backend/train/TensorPlanner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,49 @@ ir::OperandIndexSequence TensorPlanner::getOutgoingBackPropSeq(const ir::Operati
return ret;
}

void TensorPlanner::planLayerScopeTensors(TensorBuilder *tensor_builder)
{
// forwading order
const auto f_order = _tgraph.topolSortOperations();
for (const auto &op_index : f_order)
{
if (not tensor_builder->isRegisteredLayerScopeTensor(op_index))
continue;

const auto &indices = tensor_builder->getRegisteredLayerScopeTensorIndices(op_index);
for (const auto &idx : indices)
{
const auto lt = tensor_builder->getLayerScopeTensorLifeTime(idx);
if (lt == LayerScopeTensorLifeTime::FORWARD_TO_BACKWARD)
tensor_builder->notifyLayerScopeFirstUse(idx);
}
}

// backwarding order
const auto b_order = _tgraph.essentialBackwardOrder();
for (const auto &op_index : b_order)
{
if (not tensor_builder->isRegisteredLayerScopeTensor(op_index))
continue;

const auto &indices = tensor_builder->getRegisteredLayerScopeTensorIndices(op_index);

for (const auto &idx : indices)
{
const auto lt = tensor_builder->getLayerScopeTensorLifeTime(idx);
if (lt == LayerScopeTensorLifeTime::BACKWARD)
tensor_builder->notifyLayerScopeFirstUse(idx);
}
for (const auto &idx : indices)
{
const auto lt = tensor_builder->getLayerScopeTensorLifeTime(idx);
if (lt == LayerScopeTensorLifeTime::FORWARD_TO_BACKWARD ||
lt == LayerScopeTensorLifeTime::BACKWARD)
tensor_builder->notifyLayerScopeLastUse(idx);
}
}
}

} // namespace train
} // namespace backend
} // namespace onert
1 change: 1 addition & 0 deletions runtime/onert/backend/train/TensorPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class TensorPlanner
void planBackPropTensors(TensorBuilder *tensor_builder);
void planGradientTensors(TensorBuilder *tensor_builder);
void planDisposableBackPropTensors(TensorBuilder *tensor_builder);
void planLayerScopeTensors(TensorBuilder *tensor_builder);

private:
ir::OperandIndexSequence getOutgoingBackPropSeq(const ir::OperationIndex &op_index,
Expand Down

0 comments on commit 466f6c9

Please sign in to comment.