From 3327631c0787711c3eb328ab98106ea56dd616e0 Mon Sep 17 00:00:00 2001 From: SaeHie Park Date: Mon, 4 Sep 2023 03:47:07 +0000 Subject: [PATCH] [luci] Support TransposeConv activation This will revise to support TransposeConv activation. ONE-DCO-1.0-Signed-off-by: SaeHie Park --- .../luci/export/src/CircleBuiltinTypesExtractor.h | 3 ++- .../luci/import/src/Nodes/CircleTransposeConv.cpp | 1 + .../lang/include/luci/IR/Nodes/CircleTransposeConv.h | 1 + .../luci/logex/src/CircleNodeSummaryBuilder.test.cpp | 8 ++++++++ .../luci/logex/src/CircleNodeSummaryBuilders.cpp | 3 +++ .../partition/src/Nodes/CircleTransposeConv.test.cpp | 1 + compiler/luci/pass/src/FuseAddWithTConvPass.cpp | 3 +++ .../luci/pass/src/FuseBatchNormWithTConvPass.cpp | 6 ++++++ .../luci/pass/src/QuantizePreCheckerPass.test.cpp | 1 + .../luci/service/src/Nodes/CircleTransposeConv.cpp | 1 + .../service/src/Nodes/CircleTransposeConv.test.cpp | 12 ++++++++++++ 11 files changed, 39 insertions(+), 1 deletion(-) diff --git a/compiler/luci/export/src/CircleBuiltinTypesExtractor.h b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h index 0b2d2cded2b..811373ffe34 100644 --- a/compiler/luci/export/src/CircleBuiltinTypesExtractor.h +++ b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h @@ -485,7 +485,8 @@ class BuiltinOptionsExtractor final flatbuffers::Offset visit(luci::CircleTransposeConv *node) { return circle::CreateTransposeConvOptions(_builder, getOpPadding(node->padding()), - node->stride()->w(), node->stride()->h()) + node->stride()->w(), node->stride()->h(), + to_circle_actfunc(node->fusedActivationFunction())) .Union(); } flatbuffers::Offset visit(luci::CircleUnidirectionalSequenceLSTM *node) diff --git a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp index 01a28cb8af4..62326f435a1 100644 --- a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp +++ b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp @@ -74,6 +74,7 @@ CircleNode *CircleTransposeConvGraphBuilder::build_node(const circle::OperatorT node->padding(luci_padding(options->padding)); node->stride()->w(options->stride_w); node->stride()->h(options->stride_h); + node->fusedActivationFunction(luci_actfunc(options->fused_activation_function)); return node; } diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleTransposeConv.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleTransposeConv.h index 5ae41c0c422..8c6f04a588f 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleTransposeConv.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleTransposeConv.h @@ -35,6 +35,7 @@ namespace luci */ class CircleTransposeConv final : public FixedArityNode<4, CircleNodeImpl>, + public CircleNodeMixin, public CircleNodeMixin { public: diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp index 89ea213e0a0..ae76badc61b 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp @@ -298,6 +298,7 @@ TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate) { luci::CircleTransposeConv node; node.padding(luci::Padding::SAME); + node.fusedActivationFunction(luci::FusedActFunc::RELU); EXPECT_TRUE(mock_build(&node)); } @@ -307,3 +308,10 @@ TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate_padding_NEG) node.padding(luci::Padding::UNDEFINED); EXPECT_FALSE(mock_build(&node)); } + +TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate_fused_NEG) +{ + luci::CircleTransposeConv node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp index d18105c1697..aba6a86815f 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp @@ -1018,6 +1018,8 @@ bool CircleTransposeConvSummaryBuilder::validate(const luci::CircleNode *node) auto transpose_conv = loco::must_cast(node); if (transpose_conv->padding() == luci::Padding::UNDEFINED) return false; + if (transpose_conv->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return false; return true; } @@ -1034,6 +1036,7 @@ void CircleTransposeConvSummaryBuilder::build_attributes(const luci::CircleNode auto transpose_conv = loco::must_cast(node); s.args().append("stride(h,w)", to_str(transpose_conv->stride())); s.args().append("padding", to_str(transpose_conv->padding())); + s.args().append("fused_activation_function", to_str(transpose_conv->fusedActivationFunction())); } std::vector diff --git a/compiler/luci/partition/src/Nodes/CircleTransposeConv.test.cpp b/compiler/luci/partition/src/Nodes/CircleTransposeConv.test.cpp index 68adaad81d8..7dbdfd92f35 100644 --- a/compiler/luci/partition/src/Nodes/CircleTransposeConv.test.cpp +++ b/compiler/luci/partition/src/Nodes/CircleTransposeConv.test.cpp @@ -38,6 +38,7 @@ class NodeGraphlet : public NodeGraphletT NodeGraphletT::init(g); _node->padding(luci::Padding::VALID); + _node->fusedActivationFunction(luci::FusedActFunc::RELU); } }; diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp index 852bc8b63a3..d8e9f11f585 100644 --- a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp +++ b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp @@ -44,6 +44,9 @@ namespace */ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv) { + // skip if tconv has fused activation + if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; // check whether it has bias or not. This optimization works only if it doesn't. auto bias = dynamic_cast(tconv->bias()); if (not bias) diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp index 265a8398bc2..919ce6edcd1 100644 --- a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp +++ b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp @@ -87,6 +87,9 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add) return false; if (not luci::fill(&scale, &tconv).with_commutative_args_of(mul)) return false; + // skip if tconv has fused activation + if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; // check scale and shift constant attributes // TODO maybe rank check is not needed @@ -215,6 +218,9 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add) fused_tconv->stride()->h(tconv->stride()->h()); fused_tconv->stride()->w(tconv->stride()->w()); fused_tconv->name(name + "/TransposeConv"); + // TODO set activation from Add and remove adding following Relu/Relu6 Op + // when all of our backends supports fused activation of TransposeConv + fused_tconv->fusedActivationFunction(luci::FusedActFunc::NONE); luci::add_origin(fused_tconv, luci::composite_origin( {luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)})); diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp index 788353cd8ee..8f6a96f3330 100644 --- a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp +++ b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp @@ -206,6 +206,7 @@ class SimpleTransposeConvGraph transpose_conv->outBackprop(input_1); transpose_conv->filter(filter); transpose_conv->inputSizes(input_sizes); + transpose_conv->fusedActivationFunction(luci::FusedActFunc::NONE); if (make_valid) { diff --git a/compiler/luci/service/src/Nodes/CircleTransposeConv.cpp b/compiler/luci/service/src/Nodes/CircleTransposeConv.cpp index 5d2fa482d19..73aad2eb6df 100644 --- a/compiler/luci/service/src/Nodes/CircleTransposeConv.cpp +++ b/compiler/luci/service/src/Nodes/CircleTransposeConv.cpp @@ -30,6 +30,7 @@ luci::CircleNode *CloneNodeLet::visit(const luci::CircleTransposeConv cloned->padding(node->padding()); cloned->stride()->h(node->stride()->h()); cloned->stride()->w(node->stride()->w()); + cloned->fusedActivationFunction(node->fusedActivationFunction()); } return cloned; } diff --git a/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp b/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp index 29a656c0373..e9ac6e6ff7f 100644 --- a/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp @@ -32,6 +32,7 @@ TEST(CloneNodeTest, clone_TransposeConv) auto cloned_trconv = dynamic_cast(cloned); ASSERT_NE(nullptr, cloned_trconv); ASSERT_EQ(node_trconv->padding(), cloned_trconv->padding()); + ASSERT_EQ(node_trconv->fusedActivationFunction(), cloned_trconv->fusedActivationFunction()); } TEST(CloneNodeTest, clone_TransposeConv_padding_NEG) @@ -44,3 +45,14 @@ TEST(CloneNodeTest, clone_TransposeConv_padding_NEG) auto cloned = luci::clone_node(node_trconv, gc.get()); ASSERT_EQ(nullptr, cloned); } + +TEST(CloneNodeTest, clone_TransposeConv_fAF_NEG) +{ + auto g = loco::make_graph(); + auto node_trconv = g->nodes()->create(); + node_trconv->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_trconv, gc.get()); + ASSERT_EQ(nullptr, cloned); +}