Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
seanshpark committed Aug 31, 2023
1 parent 82ecc37 commit f7ad16d
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 0 deletions.
8 changes: 8 additions & 0 deletions compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate)
{
luci::CircleTransposeConv node;
node.padding(luci::Padding::SAME);
node.fusedActivationFunction(luci::FusedActFunc::RELU);
EXPECT_TRUE(mock_build(&node));
}

Expand All @@ -307,3 +308,10 @@ TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate_padding_NEG)
node.padding(luci::Padding::UNDEFINED);
EXPECT_FALSE(mock_build(&node));
}

TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate_fused_NEG)
{
luci::CircleTransposeConv node;
node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
EXPECT_FALSE(mock_build(&node));
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class NodeGraphlet : public NodeGraphletT<luci::CircleTransposeConv>
NodeGraphletT<luci::CircleTransposeConv>::init(g);

_node->padding(luci::Padding::VALID);
_node->fusedActivationFunction(luci::FusedActFunc::RELU);
}
};

Expand Down
1 change: 1 addition & 0 deletions compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
fused_tconv->stride()->h(tconv->stride()->h());
fused_tconv->stride()->w(tconv->stride()->w());
fused_tconv->name(name + "/TransposeConv");
fused_tconv->fusedActivationFunction(tconv->fusedActivationFunction());
luci::add_origin(fused_tconv,
luci::composite_origin(
{luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)}));
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class SimpleTransposeConvGraph
transpose_conv->outBackprop(input_1);
transpose_conv->filter(filter);
transpose_conv->inputSizes(input_sizes);
transpose_conv->fusedActivationFunction(luci::FusedActFunc::NONE);

if (make_valid)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ TEST(CloneNodeTest, clone_TransposeConv)
auto cloned_trconv = dynamic_cast<luci::CircleTransposeConv *>(cloned);
ASSERT_NE(nullptr, cloned_trconv);
ASSERT_EQ(node_trconv->padding(), cloned_trconv->padding());
ASSERT_EQ(node_trconv->fusedActivationFunction(), cloned_trconv->fusedActivationFunction());
}

TEST(CloneNodeTest, clone_TransposeConv_padding_NEG)
Expand Down
5 changes: 5 additions & 0 deletions compiler/tflchef/core/src/Op/TransposeConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ flatbuffers::Offset<void> TransposeConvChef::value(flatbuffers::FlatBufferBuilde
options_builder.add_stride_h(operation.transpose_conv_options().stride_h());
options_builder.add_stride_w(operation.transpose_conv_options().stride_w());

// TODO enable activation
// auto tflite_activation = as_tflite_activation(operation.sub_options().activation());
auto tflite_activation = tflite::ActivationFunctionType_NONE;
options_builder.add_fused_activation_function(tflite_activation);

return options_builder.Finish().Union();
}

Expand Down

0 comments on commit f7ad16d

Please sign in to comment.