diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
index d8e9f11f585..3563aa8fe52 100644
--- a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
+++ b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
@@ -16,11 +16,18 @@
 
 #include "luci/Pass/FuseAddWithTConvPass.h"
 
+#include "helpers/NodeFiller.h"
+
 #include <luci/IR/CircleNodes.h>
 #include <luci/Profile/CircleNodeOrigin.h>
 
 namespace
 {
+
+#define RETURN_FALSE_UNLESS(cond) \
+  if (not(cond))                  \
+    return false;
+
 /**
  * Fuse Add to TransposeConv if possible
  *
@@ -42,89 +49,74 @@ namespace
  *
  * Note: CircleRelu/Relu6 is inserted if Add activation is ReLU6
  */
-bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
+bool fuse_add_with_tconv(luci::CircleAdd *add)
 {
-  // skip if tconv has fused activation
-  if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
-    return false;
-  // check whether it has bias or not. This optimization works only if it doesn't.
-  auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
-  if (not bias)
-    return false;
-
-  // get weight of tconv
-  auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
-  if (not filter)
-    return false;
-  if (filter->dtype() != loco::DataType::FLOAT32)
-    return false;
-
-  // get add node
-  auto tconv_output = loco::succs(tconv);
-  assert(tconv_output.size() == 1);
-  auto add = dynamic_cast<luci::CircleAdd *>(*tconv_output.begin());
-  if (not add)
-    return false;
-  if (add->dtype() != loco::DataType::FLOAT32)
-    return false;
-  if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
-      add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
-      add->fusedActivationFunction() != luci::FusedActFunc::RELU)
-    return false;
-
-  // get addition
+  // Allow Add node only with FLOAT32 data type.
+  RETURN_FALSE_UNLESS(add->dtype() == loco::DataType::FLOAT32);
+  // Allow Add node only with specific activations.
+  RETURN_FALSE_UNLESS(add->fusedActivationFunction() == luci::FusedActFunc::NONE ||
+                      add->fusedActivationFunction() == luci::FusedActFunc::RELU6 ||
+                      add->fusedActivationFunction() == luci::FusedActFunc::RELU);
+  // Find the pattern of Add(TransposeConv, CircleConst):
+  luci::CircleTransposeConv *tconv = nullptr;
   luci::CircleConst *addition = nullptr;
-  if (add->x() == tconv)
-    addition = dynamic_cast<luci::CircleConst *>(add->y());
-  else
-    addition = dynamic_cast<luci::CircleConst *>(add->x());
+  RETURN_FALSE_UNLESS(luci::fill(&tconv, &addition).with_commutative_args_of(add));
 
-  if (not addition)
-    return false;
+  RETURN_FALSE_UNLESS(loco::succs(tconv).size() == 1);
+
+  // Skip if tconv has fused activation.
+  RETURN_FALSE_UNLESS(tconv->fusedActivationFunction() == luci::FusedActFunc::NONE);
+  // Check whether tconv has bias or not. This optimization works only if it doesn't.
+  auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
+  RETURN_FALSE_UNLESS(bias);
+  // Get weights of tconv:
+  auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
+  RETURN_FALSE_UNLESS(filter);
+  RETURN_FALSE_UNLESS(filter->dtype() == loco::DataType::FLOAT32);
 
   // addition dim(0) == tconv filter channel dim
-  if (addition->rank() != 1)
-    return false;
+  RETURN_FALSE_UNLESS(addition->rank() == 1);
+
   auto addition_dim = addition->dim(0).value();
   auto filter_channel_dim = filter->dim(0).value();
-  if (filter_channel_dim != addition_dim)
-    return false;
+  RETURN_FALSE_UNLESS(filter_channel_dim == addition_dim);
 
-  // fuse addition with transposed conv
+  // Fuse addition with transposed conv:
   tconv->bias(addition);
 
   if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
   {
     auto name = addition->name();
     assert(name.length() > 0);
-    // separate relu op from add op
+    // Separate relu op from add op:
     auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
     relu->features(tconv);
     relu->name(name + "/Relu6");
     luci::add_origin(relu, luci::get_origin(add));
 
-    // remove add node
+    // Remove add node.
     replace(add).with(relu);
   }
   else if (add->fusedActivationFunction() == luci::FusedActFunc::RELU)
   {
     auto name = addition->name();
     assert(name.length() > 0);
-    // separate relu op from add op
+    // Separate relu op from add op:
     auto relu = add->graph()->nodes()->create<luci::CircleRelu>();
     relu->features(tconv);
     relu->name(name + "/Relu");
     luci::add_origin(relu, luci::get_origin(add));
 
-    // remove add node
+    // Remove add node.
     replace(add).with(relu);
   }
   else
   {
+    // Remove add node.
     replace(add).with(tconv);
   }
 
-  // set origin
+  // Set new origin.
   luci::add_origin(tconv, luci::get_origin(add));
 
   return true;
@@ -140,12 +132,9 @@ bool FuseAddWithTConvPass::run(loco::Graph *g)
   bool changed = false;
   for (auto node : loco::active_nodes(loco::output_nodes(g)))
   {
-    auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
-    if (not tconv)
-      continue;
-
-    if (fuse_add_with_tconv(tconv))
-      changed = true;
+    if (auto add = dynamic_cast<luci::CircleAdd *>(node))
+      if (fuse_add_with_tconv(add))
+        changed = true;
   }
 
   return changed;
diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
index 8748d73efcd..3bf7d112aa2 100644
--- a/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
+++ b/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
@@ -16,9 +16,196 @@
 
 #include "luci/Pass/FuseAddWithTConvPass.h"
 
+#include "helpers/CreateCircleConst.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/test/TestIOGraph.h>
+
 #include <gtest/gtest.h>
 
-TEST(FuseAddWithTConvPassTest, name)
+#define ADD_VAL 5.0f
+
+namespace
+{
+
+using namespace luci::test;
+
+/**
+ *  Graph for this test
+ *
+ *  BEFORE (without extra_successor)
+ *
+ *                      |
+ *      [CircleConst]  [CircleTransposeConv]
+ *                 \    |
+ *            [CircleAdd w/ Relu]
+ *                      |
+ *
+ *  BEFORE (with extra_successor)
+ *
+ *                      |
+ *      [CircleConst]  [CircleTransposeConv]
+ *                 \    |          |
+ *            [CircleAdd w/ Relu] [extra FC]
+ *                      |          |
+ *
+ *  AFTER (if pass was successful)
+ *
+ *                            |
+ *      [CircleConst as bias] |
+ *                   \        |
+ *             [CircleTransposeConv]
+ *                       |
+ *              ([CircleRelu/Relu6])
+ *                       |
+ *
+ */
+class TConvAddGraphlet
+{
+public:
+  void init(loco::Graph *g, luci::FusedActFunc tconv_activation, bool use_bias,
+            bool extra_successor)
+  {
+    _tconv = g->nodes()->create<luci::CircleTransposeConv>();
+
+    std::vector<float> input_sizes_val = {1, 4, 4, 1};
+    _tconv_i = luci::create_const_node(g, loco::DataType::FLOAT32, {4}, input_sizes_val);
+    _tconv->inputSizes(_tconv_i);
+
+    std::vector<float> filter_val(18);
+    for (uint32_t i = 0; i < 18; i++)
+      filter_val.at(i) = i;
+
+    _tconv_f = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 3, 3, 2}, filter_val);
+    _tconv->filter(_tconv_f);
+
+    if (use_bias)
+    {
+      std::vector<float> bias_val(1, 3.0f);
+      _tconv_b = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, bias_val);
+    }
+    else
+    {
+      // Create CircleOutputExclude -- no bias
+      _tconv_b = g->nodes()->create<luci::CircleOutputExclude>();
+    }
+    _tconv->bias(_tconv_b);
+
+    _tconv->padding(luci::Padding::VALID);
+    auto _stride = _tconv->stride();
+    _stride->w(1);
+    _stride->h(1);
+    _tconv->fusedActivationFunction(tconv_activation);
+    _tconv->dtype(loco::DataType::FLOAT32);
+    _tconv->shape({1, 4, 4, 1});
+    _tconv->name("tconv");
+
+    if (extra_successor)
+    {
+      _extra_succ = g->nodes()->create<luci::CircleFullyConnected>();
+      // Set previous TConv as input to bump number of successors for it:
+      _extra_succ->input(_tconv);
+      std::vector<float> weights_val(8);
+      _extra_f = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 8}, weights_val);
+      _extra_succ->weights(_extra_f);
+      _extra_succ->bias(nullptr);
+      _extra_succ->fusedActivationFunction(luci::FusedActFunc::NONE);
+      _extra_succ->dtype(loco::DataType::FLOAT32);
+      _extra_succ->shape({1, 4, 4, 1});
+      _extra_succ->name("extra_fc");
+    }
+
+    std::vector<float> add_values(1, ADD_VAL);
+    _add_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, add_values);
+    _add_c->name("const_c");
+
+    _add = g->nodes()->create<luci::CircleAdd>();
+    _add->x(_tconv);
+    _add->y(_add_c);
+    _add->fusedActivationFunction(luci::FusedActFunc::RELU);
+    _add->dtype(loco::DataType::FLOAT32);
+    _add->shape({1, 4, 4, 1});
+
+    _add->name("add");
+  }
+
+protected:
+  luci::CircleTransposeConv *_tconv = nullptr;
+  luci::CircleConst *_tconv_i = nullptr;
+  luci::CircleConst *_tconv_f = nullptr;
+  luci::CircleNode *_tconv_b = nullptr;
+  luci::CircleAdd *_add = nullptr;
+  luci::CircleConst *_add_c = nullptr;
+  luci::CircleFullyConnected *_extra_succ = nullptr;
+  luci::CircleConst *_extra_f = nullptr;
+};
+
+class FuseAddWithTConvTestGraph : public TestIOGraph, public TConvAddGraphlet
+{
+public:
+  void init(luci::FusedActFunc tconv_activation, bool use_bias, bool extra_successor)
+  {
+    TestIOGraph::init({1, 2, 2, 2}, {1, 4, 4, 1});
+    TConvAddGraphlet::init(g(), tconv_activation, use_bias, extra_successor);
+
+    _tconv->outBackprop(input());
+
+    output()->from(_add);
+  }
+};
+
+class FuseAddWithTConvPassTest : public ::testing::Test
+{
+public:
+  FuseAddWithTConvTestGraph g;
+  luci::FuseAddWithTConvPass pass;
+};
+
+} // namespace
+
+TEST_F(FuseAddWithTConvPassTest, tconv_add_fuse)
+{
+  g.init(luci::FusedActFunc::NONE, false /* use_bias */, false /* extra_successor */);
+
+  EXPECT_EQ(true, pass.run(g.g()));
+
+  auto relu = dynamic_cast<luci::CircleRelu *>(g.output()->from());
+  EXPECT_NE(nullptr, relu);
+  EXPECT_STREQ(relu->name().c_str(), "const_c/Relu");
+
+  auto tconv = dynamic_cast<luci::CircleTransposeConv *>(relu->features());
+  EXPECT_NE(nullptr, tconv);
+
+  auto bias = loco::must_cast<luci::CircleConst *>(tconv->bias());
+  EXPECT_NE(nullptr, bias);
+
+  for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++)
+  {
+    EXPECT_EQ(ADD_VAL, bias->at<loco::DataType::FLOAT32>(i));
+  }
+}
+
+TEST_F(FuseAddWithTConvPassTest, tconv_with_bias_NEG)
+{
+  g.init(luci::FusedActFunc::NONE, true /* use_bias */, false /* extra_successor */);
+
+  EXPECT_EQ(false, pass.run(g.g()));
+}
+
+TEST_F(FuseAddWithTConvPassTest, tconv_with_activation_NEG)
+{
+  g.init(luci::FusedActFunc::RELU, false /* use_bias */, false /* extra_successor */);
+
+  EXPECT_EQ(false, pass.run(g.g()));
+}
+
+TEST_F(FuseAddWithTConvPassTest, tconv_with_extra_successor_NEG)
+{
+  g.init(luci::FusedActFunc::NONE, false /* use_bias */, true /* extra_successor */);
+
+  EXPECT_EQ(false, pass.run(g.g()));
+}
+
+TEST_F(FuseAddWithTConvPassTest, name)
 {
-  luci::FuseAddWithTConvPass pass;
   auto const name = pass.name();
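
Note: the rewritten matcher relies on luci::fill(...).with_commutative_args_of(...) from the newly included helpers/NodeFiller.h. For readers who have not seen that helper, the following is a minimal sketch of the idiom, assuming the helper simply tries both operand orders of a commutative node; the real header in the repository is authoritative and may differ in detail.

// Sketch only: assumed shape of the NodeFiller helper used by the pass above.
namespace luci
{

template <class ARG_TYPE_1, class ARG_TYPE_2> class NodeFiller final
{
public:
  NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1(arg_1), _arg_2(arg_2) {}

  // Returns true if the inputs of `node` match (ARG_TYPE_1, ARG_TYPE_2) in either
  // order; on success, fills both out-pointers with the dynamic_cast results.
  template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node)
  {
    // Case 1: x is ARG_TYPE_1, y is ARG_TYPE_2.
    {
      auto x = dynamic_cast<ARG_TYPE_1 *>(node->x());
      auto y = dynamic_cast<ARG_TYPE_2 *>(node->y());
      if (x != nullptr && y != nullptr)
      {
        *_arg_1 = x;
        *_arg_2 = y;
        return true;
      }
    }
    // Case 2: the same match with the operands swapped.
    {
      auto x = dynamic_cast<ARG_TYPE_1 *>(node->y());
      auto y = dynamic_cast<ARG_TYPE_2 *>(node->x());
      if (x != nullptr && y != nullptr)
      {
        *_arg_1 = x;
        *_arg_2 = y;
        return true;
      }
    }
    return false;
  }

private:
  ARG_TYPE_1 **_arg_1;
  ARG_TYPE_2 **_arg_2;
};

template <class ARG_TYPE_1, class ARG_TYPE_2>
NodeFiller<ARG_TYPE_1, ARG_TYPE_2> fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
{
  return NodeFiller<ARG_TYPE_1, ARG_TYPE_2>{arg_1, arg_2};
}

} // namespace luci

With this shape, fill(&tconv, &addition).with_commutative_args_of(add) succeeds whether the constant arrives as add->x() or add->y(), which is what lets fuse_add_with_tconv drop the old hand-written x/y branching that this diff deletes.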