Skip to content

Commit

Permalink
[onert] Unify interface for generating tensors (#13631)
Browse files Browse the repository at this point in the history
This commit unifies interface for generating tensors.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Aug 9, 2024
1 parent 45fbbe2 commit 484774f
Show file tree
Hide file tree
Showing 6 changed files with 2 additions and 24 deletions.
9 changes: 1 addition & 8 deletions runtime/onert/backend/train/BackendContext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,9 @@ getDisposableBackPropTensorList(const ir::train::TrainableGraph &tgraph,
backend::ITensorRegistry *BackendContext::genTensors()
{
planForwardTensors();

_tensor_builder->allocate();

return _tensor_registry.get();
}

backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
{
planBackwardTensors();

_tensor_builder->allocate();
_tensor_builder->allocateBackward();

return _tensor_registry.get();
Expand Down
1 change: 0 additions & 1 deletion runtime/onert/backend/train/BackendContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class BackendContext : public onert::backend::train::TrainableBackendContext

public:
backend::ITensorRegistry *genTensors() override;
backend::train::ITensorRegistry *genTrainingTensors() override;

private:
void planForwardTensors();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ class TrainableBackendContext

std::shared_ptr<ITensorRegistry> tensor_registry() { return _tensor_registry; }

virtual ITensorRegistry *genTrainingTensors() = 0;
virtual backend::ITensorRegistry *genTensors() = 0;
virtual FunctionMap genKernels() = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,13 @@ namespace train

backend::ITensorRegistry *BackendContext::genTensors()
{
// For now, there is no need to generate tensors for forwarding.
// For now, there is no need to generate tensors for forwarding and backwarding.
// builtin train backend handles 3 operators: `Permute`, `IF`, `WHILE`.
// `Permute`: Tensor generation is not required.
// `IF`, `WHILE`: Not supported yet
return tensor_registry().get();
}

backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
{
// For now, there is no need to generate tensors for backwarding.
return tensor_registry().get();
}

backend::train::FunctionMap BackendContext::genKernels()
{
backend::train::FunctionMap ret;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class BackendContext : public backend::train::TrainableBackendContext
}

backend::ITensorRegistry *genTensors() override;
backend::train::ITensorRegistry *genTrainingTensors() override;

public:
backend::train::FunctionMap genKernels() override;
Expand Down
6 changes: 0 additions & 6 deletions runtime/onert/core/src/compiler/ExecutorFactory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -746,12 +746,6 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
pair.second->genTensors();
}

for (auto &&pair : tbackend_contexts)
{
auto tctx = pair.second.get();
tctx->genTrainingTensors();
}

prepareMigrantTensors(*lowered_graph, tbackend_contexts);

// Give some runtime objects to builtin KernelGenerator
Expand Down

0 comments on commit 484774f

Please sign in to comment.