diff --git a/runtime/onert/core/include/ir/operation/Permute.h b/runtime/onert/core/include/ir/operation/Permute.h index 10f09b9a03b..4866e9d1613 100644 --- a/runtime/onert/core/include/ir/operation/Permute.h +++ b/runtime/onert/core/include/ir/operation/Permute.h @@ -34,28 +34,32 @@ namespace ir namespace operation { +/** + * @brief Class to represent Permute operation + * @note Permute operation reorders the dimensions of a tensor. + * + * This operation is virtual operation, which is not used on real model, but used internally. + * It was introduced to support various model layout (NHWC, NCHW, etc) and backend layout. + * But currently, model layout and backend layout are always same as NHWC. + * So this operation is used for below cases. + * 1) Handle model output buffer's special case + * 1-1) Model output is comes from model constant + * 1-2) Model output is comes from model input + * 1-3) Model output shares tensor with other model output(s) + * 2) Handle shared tensor between different backend + * + * Q) Why name is still 'Permute'? + * A) It is handled as copy operation on compile phase, + * but it can be permute operation if output buffer layout is changed by API call + */ class Permute : public Operation { -public: - enum class Type - { - NHWC_TO_NCHW, - NCHW_TO_NHWC, - COPY - }; - public: void accept(OperationVisitor &v) const override; OpCode opcode() const final { return OpCode::Permute; } public: - Permute(const OperandIndex &input, const OperandIndex &output, Type type); - -public: - Type getPermuteType() const { return _type; } - -private: - Type _type; + Permute(const OperandIndex &input, const OperandIndex &output); }; } // namespace operation diff --git a/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc index 1448de47354..0e961c62316 100644 --- a/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc +++ b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc @@ -41,7 +41,7 @@ void ConstantOutputPass::callback(const ir::OperandIndex &ind, ir::Operand &obj) obj.info().setAsNonConst(); using ir::operation::Permute; - auto permute_obj = std::make_unique(permute_input_ind, ind, Permute::Type::COPY); + auto permute_obj = std::make_unique(permute_input_ind, ind); auto permute_ind = _graph.operations().push(std::move(permute_obj)); permute_input_obj.insertUse(permute_ind); diff --git a/runtime/onert/core/src/compiler/pass/OddOutputPass.cc b/runtime/onert/core/src/compiler/pass/OddOutputPass.cc index e2b3f6111ed..5aabd2aa202 100644 --- a/runtime/onert/core/src/compiler/pass/OddOutputPass.cc +++ b/runtime/onert/core/src/compiler/pass/OddOutputPass.cc @@ -71,7 +71,7 @@ ir::OperandIndex OddOutputPass::insertPermute(ir::OperandIndex ind) auto &output_obj = _graph.operands().at(output_ind); using ir::operation::Permute; - auto permute_obj = std::make_unique(ind, output_ind, Permute::Type::COPY); + auto permute_obj = std::make_unique(ind, output_ind); auto permute_ind = _graph.operations().push(std::move(permute_obj)); output_obj.setDef(permute_ind); diff --git a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc index 586e97f9686..d49b88c6800 100644 --- a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc +++ b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc @@ -154,8 +154,7 @@ ir::OperationIndex PermutationInsertionPass::insertPermute(const ir::OperandInde // Insert permute operation to the graph using Permute = ir::operation::Permute; - auto insert_node = - std::make_unique(operand_index, out_operand_index, Permute::Type::COPY); + auto insert_node = std::make_unique(operand_index, out_operand_index); auto node_index = _graph.operations().push(std::move(insert_node)); diff --git a/runtime/onert/core/src/ir/OperationDumper.cc b/runtime/onert/core/src/ir/OperationDumper.cc index 5aa4693adaf..e0f28795b9d 100644 --- a/runtime/onert/core/src/ir/OperationDumper.cc +++ b/runtime/onert/core/src/ir/OperationDumper.cc @@ -268,20 +268,8 @@ void OperationDumper::visit(const Pad &node) void OperationDumper::visit(const Permute &node) { std::string permute_type = "Unknown"; - switch (node.getPermuteType()) - { - case Permute::Type::COPY: - permute_type = "Copy"; - break; - case Permute::Type::NHWC_TO_NCHW: - permute_type = "NHWC to NCHW"; - break; - case Permute::Type::NCHW_TO_NHWC: - permute_type = "NCHW to NHWC"; - break; - } - VERBOSE(LIR) << "* Permute(" + permute_type + ")" << std::endl; + VERBOSE(LIR) << "* " << node.name() << std::endl; VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs().at(0) << ")" << std::endl; VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl; } diff --git a/runtime/onert/core/src/ir/operation/Permute.cc b/runtime/onert/core/src/ir/operation/Permute.cc index 813fbaf30ad..77ec42125c7 100644 --- a/runtime/onert/core/src/ir/operation/Permute.cc +++ b/runtime/onert/core/src/ir/operation/Permute.cc @@ -26,8 +26,8 @@ namespace operation void Permute::accept(OperationVisitor &v) const { v.visit(*this); } -Permute::Permute(const OperandIndex &input, const OperandIndex &output, Type type) - : Operation{OperandConstraint::createExact(1u)}, _type{type} +Permute::Permute(const OperandIndex &input, const OperandIndex &output) + : Operation{OperandConstraint::createExact(1u)} { setInputs({input}); setOutputs({output}); diff --git a/runtime/onert/core/src/ir/train/operation/Permute.cc b/runtime/onert/core/src/ir/train/operation/Permute.cc index adc23aa49b7..b1dfcec0336 100644 --- a/runtime/onert/core/src/ir/train/operation/Permute.cc +++ b/runtime/onert/core/src/ir/train/operation/Permute.cc @@ -38,8 +38,7 @@ void Permute::accept(OperationVisitor &v) const { v.visit(*this); } void Permute::accept(TrainableOperationVisitor &v) const { v.visit(*this); } Permute::Permute(const OperationType &operation) - : OperationType{operation.getInputs().at(0), operation.getOutputs().at(0), - operation.getPermuteType()} + : OperationType{operation.getInputs().at(0), operation.getOutputs().at(0)} { // DO NOTHING } diff --git a/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc b/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc index e3472ec51e9..4a6267b0a34 100644 --- a/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc +++ b/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc @@ -332,7 +332,7 @@ operation::Pad generatePad() operation::Permute generatePermute() { - return operation::Permute{OperandIndex{1}, OperandIndex{0}, operation::Permute::Type::COPY}; + return operation::Permute{OperandIndex{1}, OperandIndex{0}}; } operation::Pool2D generatePool2D()