From 414297b4dc8cdd7edc271ab7c5874f26d1166aa2 Mon Sep 17 00:00:00 2001 From: SaeHie Park Date: Thu, 31 Aug 2023 10:08:15 +0900 Subject: [PATCH] fix test --- compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp | 8 ++++++++ .../luci/partition/src/Nodes/CircleTransposeConv.test.cpp | 1 + compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp | 1 + compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp | 1 + .../luci/service/src/Nodes/CircleTransposeConv.test.cpp | 1 + compiler/tflchef/core/src/Op/TransposeConv.cpp | 5 +++++ 6 files changed, 17 insertions(+) diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp index 89ea213e0a0..ae76badc61b 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp @@ -298,6 +298,7 @@ TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate) { luci::CircleTransposeConv node; node.padding(luci::Padding::SAME); + node.fusedActivationFunction(luci::FusedActFunc::RELU); EXPECT_TRUE(mock_build(&node)); } @@ -307,3 +308,10 @@ TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate_padding_NEG) node.padding(luci::Padding::UNDEFINED); EXPECT_FALSE(mock_build(&node)); } + +TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate_fused_NEG) +{ + luci::CircleTransposeConv node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleTransposeConv.test.cpp b/compiler/luci/partition/src/Nodes/CircleTransposeConv.test.cpp index 68adaad81d8..7dbdfd92f35 100644 --- a/compiler/luci/partition/src/Nodes/CircleTransposeConv.test.cpp +++ b/compiler/luci/partition/src/Nodes/CircleTransposeConv.test.cpp @@ -38,6 +38,7 @@ class NodeGraphlet : public NodeGraphletT NodeGraphletT::init(g); _node->padding(luci::Padding::VALID); + _node->fusedActivationFunction(luci::FusedActFunc::RELU); } }; diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp index 265a8398bc2..525fe8d3002 100644 --- a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp +++ b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp @@ -215,6 +215,7 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add) fused_tconv->stride()->h(tconv->stride()->h()); fused_tconv->stride()->w(tconv->stride()->w()); fused_tconv->name(name + "/TransposeConv"); + fused_tconv->fusedActivationFunction(tconv->fusedActivationFunction()); luci::add_origin(fused_tconv, luci::composite_origin( {luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)})); diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp index 788353cd8ee..8f6a96f3330 100644 --- a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp +++ b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp @@ -206,6 +206,7 @@ class SimpleTransposeConvGraph transpose_conv->outBackprop(input_1); transpose_conv->filter(filter); transpose_conv->inputSizes(input_sizes); + transpose_conv->fusedActivationFunction(luci::FusedActFunc::NONE); if (make_valid) { diff --git a/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp b/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp index 2a1ddb9a833..e9ac6e6ff7f 100644 --- a/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp @@ -32,6 +32,7 @@ TEST(CloneNodeTest, clone_TransposeConv) auto cloned_trconv = dynamic_cast(cloned); ASSERT_NE(nullptr, cloned_trconv); ASSERT_EQ(node_trconv->padding(), cloned_trconv->padding()); + ASSERT_EQ(node_trconv->fusedActivationFunction(), cloned_trconv->fusedActivationFunction()); } TEST(CloneNodeTest, clone_TransposeConv_padding_NEG) diff --git a/compiler/tflchef/core/src/Op/TransposeConv.cpp b/compiler/tflchef/core/src/Op/TransposeConv.cpp index c9e45271458..58f90c3fb75 100644 --- a/compiler/tflchef/core/src/Op/TransposeConv.cpp +++ b/compiler/tflchef/core/src/Op/TransposeConv.cpp @@ -34,6 +34,11 @@ flatbuffers::Offset TransposeConvChef::value(flatbuffers::FlatBufferBuilde options_builder.add_stride_h(operation.transpose_conv_options().stride_h()); options_builder.add_stride_w(operation.transpose_conv_options().stride_w()); + // TODO enable activation + // auto tflite_activation = as_tflite_activation(operation.sub_options().activation()); + auto tflite_activation = tflite::ActivationFunctionType_NONE; + options_builder.add_fused_activation_function(tflite_activation); + return options_builder.Finish().Union(); }