diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index cd9face54d6..c724f832b94 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -16,12 +16,12 @@ #include "luci/Pass/FuseMulWithFullyConnectedPass.h" +#include "helpers/NodeFiller.h" + #include #include #include -#include - namespace { @@ -107,10 +107,19 @@ luci::CircleConst *gen_fused_bias(luci::CircleConst *bias, const luci::CircleCon * | * */ -bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) +bool fuse_mul_with_fc(luci::CircleMul *mul) { // Sanity check: - RETURN_FALSE_UNLESS(fc); + RETURN_FALSE_UNLESS(mul); + // Allow Mul node only with FLOAT32 data type: + RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32); + // Check if any FC node connects to Mul. + // Find the pattern of Mul(FC, CircleConst): + luci::CircleFullyConnected *fc = nullptr; + luci::CircleConst *multiplication = nullptr; + RETURN_FALSE_UNLESS(luci::fill(&fc, &multiplication).with_commutative_args_of(mul)); + // Make sure that FullyConnected has only one successor: + RETURN_FALSE_UNLESS(loco::succs(fc).size() == 1); // Allow only FLOAT32 data type: RETURN_FALSE_UNLESS(fc->dtype() == loco::DataType::FLOAT32); // Allow only without activation functions as values are going to @@ -119,18 +128,6 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) // Check for weights being Constant: auto weights = dynamic_cast(fc->weights()); RETURN_FALSE_UNLESS(weights); - // Get Mul node: - auto fc_output = loco::succs(fc); - // Make sure that FullyConnected has only one child: - RETURN_FALSE_UNLESS(fc_output.size() == 1); - auto mul = dynamic_cast(*fc_output.begin()); - RETURN_FALSE_UNLESS(mul); - // Allow Mul node only with FLOAT32 data type: - RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32); - // Get multiplication Constant (here: the second input besides weights): - auto multiplication = mul->x() == fc ? dynamic_cast(mul->y()) - : dynamic_cast(mul->x()); - RETURN_FALSE_UNLESS(multiplication); // Get rank of multiplication: auto rank = multiplication->rank(); // Check that all dimensions are ones, checks broadcast capabilites. @@ -197,14 +194,14 @@ bool FuseMulWithFullyConnectedPass::run(loco::Graph *g) bool changed = false; for (auto node : loco::active_nodes(loco::output_nodes(g))) { - auto fc = dynamic_cast(node); - if (not fc) + auto mul = dynamic_cast(node); + if (not mul) continue; - switch (fc->dtype()) + switch (mul->dtype()) { case loco::DataType::FLOAT32: - if (fuse_mul_with_fc(fc)) + if (fuse_mul_with_fc(mul)) changed = true; break; default: diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp index 821f4ff3d5c..a4f9d6bf087 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp @@ -34,13 +34,22 @@ using namespace luci::test; /** * Graph for this test * - * BEFORE + * BEFORE (without extra_fc_successor) * * [FC] * | * [Mul w/ Relu] * - * AFTER + * BEFORE (with extra_fc_successor) + * + * [FC] + * | + * |------------------- + * | | + * | | + * [Mul w/ Relu] [other FC] + * + * AFTER (if pass applied) * * [FC w/ Relu] (weights and bias updated) * @@ -48,7 +57,8 @@ using namespace luci::test; class FCMulGraphlet { public: - void init(loco::Graph *g, luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias) + void init(loco::Graph *g, luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias, + bool extra_successor) { _fc = g->nodes()->create(); @@ -79,6 +89,22 @@ class FCMulGraphlet _fc->shape({1, DIM_ONE}); _fc->name("fc"); + if (extra_successor) + { + _extra_succ = g->nodes()->create(); + // Set previous FC as input to bump number of successors for it: + _extra_succ->input(_fc); + std::vector weights_val(DIM_ONE * DIM_TWO); + _extra_f = + luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE, DIM_TWO}, weights_val); + _extra_succ->weights(_extra_f); + _extra_succ->bias(nullptr); + _extra_succ->fusedActivationFunction(luci::FusedActFunc::NONE); + _extra_succ->dtype(loco::DataType::FLOAT32); + _extra_succ->shape({1, DIM_ONE}); + _extra_succ->name("extra_fc"); + } + std::vector mul_values; if (is_mul_scalar) @@ -128,15 +154,18 @@ class FCMulGraphlet luci::CircleConst *_fc_f = nullptr; luci::CircleNode *_fc_b = nullptr; luci::CircleConst *_mul_c = nullptr; + luci::CircleFullyConnected *_extra_succ = nullptr; + luci::CircleConst *_extra_f = nullptr; }; class FuseMulWithFCTestGraph : public TestIOGraph, public FCMulGraphlet { public: - void init(luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias) + void init(luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias, + bool extra_successor) { TestIOGraph::init({1, DIM_TWO}, {1, DIM_ONE}); - FCMulGraphlet::init(g(), fc_activation, is_mul_scalar, use_bias); + FCMulGraphlet::init(g(), fc_activation, is_mul_scalar, use_bias, extra_successor); _fc->input(input()); @@ -155,7 +184,8 @@ class FuseMulWithFullyConnectedPassTest : public ::testing::Test TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_tensor) { - g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */); + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); EXPECT_EQ(true, pass.run(g.g())); @@ -184,7 +214,8 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_tensor) TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_scalar) { - g.init(luci::FusedActFunc::NONE, true /* is_mul_scalar */, true /* use_bias */); + g.init(luci::FusedActFunc::NONE, true /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); EXPECT_EQ(true, pass.run(g.g())); @@ -213,7 +244,8 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_scalar) TEST_F(FuseMulWithFullyConnectedPassTest, fc_no_bias) { - g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, false /* use_bias */); + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, false /* use_bias */, + false /* extra_successor */); EXPECT_EQ(true, pass.run(g.g())); @@ -238,7 +270,8 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_no_bias) TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG) { - g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */); + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); // Bias cannot be fused as it's passed as feature map. g.to_fm_bias(); @@ -248,16 +281,26 @@ TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG) TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_activation_NEG) { - g.init(luci::FusedActFunc::RELU, false /* is_mul_scalar */, true /* use_bias */); + g.init(luci::FusedActFunc::RELU, false /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); EXPECT_EQ(false, pass.run(g.g())); } TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_null_weights_NEG) { - g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */); + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); g.fc()->weights(nullptr); EXPECT_EQ(false, pass.run(g.g())); } + +TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_extra_successor_NEG) +{ + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */, + true /* extra_successor */); + + EXPECT_EQ(false, pass.run(g.g())); +}