Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[luci] Support TransposeConv activation #11467

Merged
merged 1 commit into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion compiler/luci/export/src/CircleBuiltinTypesExtractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,8 @@ class BuiltinOptionsExtractor final
flatbuffers::Offset<void> visit(luci::CircleTransposeConv *node)
{
return circle::CreateTransposeConvOptions(_builder, getOpPadding(node->padding()),
node->stride()->w(), node->stride()->h())
node->stride()->w(), node->stride()->h(),
to_circle_actfunc(node->fusedActivationFunction()))
.Union();
}
flatbuffers::Offset<void> visit(luci::CircleUnidirectionalSequenceLSTM *node)
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/import/src/Nodes/CircleTransposeConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ CircleNode *CircleTransposeConvGraphBuilder::build_node(const circle::OperatorT
node->padding(luci_padding(options->padding));
node->stride()->w(options->stride_w);
node->stride()->h(options->stride_h);
node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));

return node;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace luci
*/
class CircleTransposeConv final
: public FixedArityNode<4, CircleNodeImpl<CircleOpcode::TRANSPOSE_CONV>>,
public CircleNodeMixin<CircleNodeTrait::FusedActFunc>,
public CircleNodeMixin<CircleNodeTrait::Bias>
{
public:
Expand Down
8 changes: 8 additions & 0 deletions compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate)
{
luci::CircleTransposeConv node;
node.padding(luci::Padding::SAME);
node.fusedActivationFunction(luci::FusedActFunc::RELU);
EXPECT_TRUE(mock_build(&node));
}

Expand All @@ -307,3 +308,10 @@ TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate_padding_NEG)
node.padding(luci::Padding::UNDEFINED);
EXPECT_FALSE(mock_build(&node));
}

TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate_fused_NEG)
{
luci::CircleTransposeConv node;
node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
EXPECT_FALSE(mock_build(&node));
}
3 changes: 3 additions & 0 deletions compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,8 @@ bool CircleTransposeConvSummaryBuilder::validate(const luci::CircleNode *node)
auto transpose_conv = loco::must_cast<const luci::CircleTransposeConv *>(node);
if (transpose_conv->padding() == luci::Padding::UNDEFINED)
return false;
if (transpose_conv->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return false;

return true;
}
Expand All @@ -1034,6 +1036,7 @@ void CircleTransposeConvSummaryBuilder::build_attributes(const luci::CircleNode
auto transpose_conv = loco::must_cast<const luci::CircleTransposeConv *>(node);
s.args().append("stride(h,w)", to_str(transpose_conv->stride()));
s.args().append("padding", to_str(transpose_conv->padding()));
s.args().append("fused_activation_function", to_str(transpose_conv->fusedActivationFunction()));
}

std::vector<std::string>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class NodeGraphlet : public NodeGraphletT<luci::CircleTransposeConv>
NodeGraphletT<luci::CircleTransposeConv>::init(g);

_node->padding(luci::Padding::VALID);
_node->fusedActivationFunction(luci::FusedActFunc::RELU);
}
};

Expand Down
3 changes: 3 additions & 0 deletions compiler/luci/pass/src/FuseAddWithTConvPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ namespace
*/
bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
{
// skip if tconv has fused activation
if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
return false;
// check whether it has bias or not. This optimization works only if it doesn't.
auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
if (not bias)
Expand Down
6 changes: 6 additions & 0 deletions compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
return false;
if (not luci::fill(&scale, &tconv).with_commutative_args_of(mul))
return false;
// skip if tconv has fused activation
if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
return false;

// check scale and shift constant attributes
// TODO maybe rank check is not needed
Expand Down Expand Up @@ -215,6 +218,9 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
fused_tconv->stride()->h(tconv->stride()->h());
fused_tconv->stride()->w(tconv->stride()->w());
fused_tconv->name(name + "/TransposeConv");
// TODO set activation from Add and remove adding following Relu/Relu6 Op
// when all of our backends supports fused activation of TransposeConv
fused_tconv->fusedActivationFunction(luci::FusedActFunc::NONE);
luci::add_origin(fused_tconv,
luci::composite_origin(
{luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)}));
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class SimpleTransposeConvGraph
transpose_conv->outBackprop(input_1);
transpose_conv->filter(filter);
transpose_conv->inputSizes(input_sizes);
transpose_conv->fusedActivationFunction(luci::FusedActFunc::NONE);

if (make_valid)
{
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/service/src/Nodes/CircleTransposeConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleTransposeConv
cloned->padding(node->padding());
cloned->stride()->h(node->stride()->h());
cloned->stride()->w(node->stride()->w());
cloned->fusedActivationFunction(node->fusedActivationFunction());
}
return cloned;
}
Expand Down
12 changes: 12 additions & 0 deletions compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ TEST(CloneNodeTest, clone_TransposeConv)
auto cloned_trconv = dynamic_cast<luci::CircleTransposeConv *>(cloned);
ASSERT_NE(nullptr, cloned_trconv);
ASSERT_EQ(node_trconv->padding(), cloned_trconv->padding());
ASSERT_EQ(node_trconv->fusedActivationFunction(), cloned_trconv->fusedActivationFunction());
}

TEST(CloneNodeTest, clone_TransposeConv_padding_NEG)
Expand All @@ -44,3 +45,14 @@ TEST(CloneNodeTest, clone_TransposeConv_padding_NEG)
auto cloned = luci::clone_node(node_trconv, gc.get());
ASSERT_EQ(nullptr, cloned);
}

TEST(CloneNodeTest, clone_TransposeConv_fAF_NEG)
{
auto g = loco::make_graph();
auto node_trconv = g->nodes()->create<luci::CircleTransposeConv>();
node_trconv->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);

auto gc = loco::make_graph();
auto cloned = luci::clone_node(node_trconv, gc.get());
ASSERT_EQ(nullptr, cloned);
}