diff --git a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
index f947951415a..2574cbea95a 100644
--- a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
+++ b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
@@ -31,40 +31,48 @@ namespace
   if (not(cond)) \
     return false;
 
-bool fc_with_add_pattern_check(const loco::DataType dtype, luci::CircleAdd **add,
-                               luci::CircleFullyConnected **fc, luci::CircleConst **addition,
-                               luci::CircleConst **weights, luci::CircleNode **bias)
+struct PatternNodes
 {
-  RETURN_FALSE_UNLESS((*add)->dtype() == dtype);
+  luci::CircleFullyConnected *fc = nullptr;
+  // addition must be const
+  luci::CircleConst *addition = nullptr;
+  luci::CircleConst *weights = nullptr;
+  luci::CircleNode *bias = nullptr;
+};
 
-  RETURN_FALSE_UNLESS(luci::fill(fc, addition).with_commutative_args_of(*add));
+bool fc_with_add_pattern_check(const loco::DataType dtype, const luci::CircleAdd &add,
+                               PatternNodes &nodes)
+{
+  RETURN_FALSE_UNLESS(add.dtype() == dtype);
+
+  RETURN_FALSE_UNLESS(luci::fill(&nodes.fc, &nodes.addition).with_commutative_args_of(&add));
 
   // Check if fc has only one successor to limit possible weights size increase.
-  RETURN_FALSE_UNLESS(loco::succs(*fc).size() == 1);
-  RETURN_FALSE_UNLESS((*fc)->dtype() == dtype);
-  RETURN_FALSE_UNLESS((*fc)->fusedActivationFunction() == luci::FusedActFunc::NONE);
+  RETURN_FALSE_UNLESS(loco::succs(nodes.fc).size() == 1);
+  RETURN_FALSE_UNLESS(nodes.fc->dtype() == dtype);
+  RETURN_FALSE_UNLESS(nodes.fc->fusedActivationFunction() == luci::FusedActFunc::NONE);
 
-  *weights = dynamic_cast<luci::CircleConst *>((*fc)->weights());
-  RETURN_FALSE_UNLESS(*weights);
+  nodes.weights = dynamic_cast<luci::CircleConst *>(nodes.fc->weights());
+  RETURN_FALSE_UNLESS(nodes.weights);
 
-  RETURN_FALSE_UNLESS((*addition)->dtype() == dtype);
+  RETURN_FALSE_UNLESS(nodes.addition->dtype() == dtype);
 
-  auto rank = (*addition)->rank();
+  auto rank = (nodes.addition)->rank();
   // TODO Support scalar addition
   RETURN_FALSE_UNLESS(rank != 0);
 
   for (uint32_t i = 0; i < rank - 1; i++)
   {
-    RETURN_FALSE_UNLESS((*addition)->dim(i).value() == 1);
+    RETURN_FALSE_UNLESS(nodes.addition->dim(i).value() == 1);
   }
   // Check the last dimesion of addition is the same with the number of neurons of FC
-  RETURN_FALSE_UNLESS((*addition)->dim(rank - 1) == (*weights)->dim(0));
+  RETURN_FALSE_UNLESS(nodes.addition->dim(rank - 1) == nodes.weights->dim(0));
 
   // We only support (1) constant bias (2) no bias
   // If bias is neither (1) nor (2), it would be a feature map
-  *bias = loco::must_cast<luci::CircleNode *>((*fc)->bias());
-  RETURN_FALSE_UNLESS((*bias)->opcode() == luci::CircleOpcode::CIRCLECONST or
-                      (*bias)->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE);
+  nodes.bias = loco::must_cast<luci::CircleNode *>(nodes.fc->bias());
+  RETURN_FALSE_UNLESS(nodes.bias->opcode() == luci::CircleOpcode::CIRCLECONST or
+                      nodes.bias->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE);
 
   return true;
 }
@@ -87,19 +95,14 @@ bool fc_with_add_pattern_check(const loco::DataType dtype, luci::CircleAdd **add
  */
 bool fuse_add_with_fc(luci::CircleAdd *add)
 {
-  luci::CircleFullyConnected *fc = nullptr;
-  // addition must be const
-  luci::CircleConst *addition = nullptr;
-  luci::CircleConst *weights = nullptr;
-  luci::CircleNode *bias = nullptr;
+  PatternNodes nodes;
 
-  RETURN_FALSE_UNLESS(
-    fc_with_add_pattern_check(loco::DataType::FLOAT32, &add, &fc, &addition, &weights, &bias));
+  RETURN_FALSE_UNLESS(fc_with_add_pattern_check(loco::DataType::FLOAT32, *add, nodes));
 
-  auto fused_bias = luci::clone(addition);
+  auto fused_bias = luci::clone(nodes.addition);
 
   // Add existing bias values
-  if (auto const_bias = dynamic_cast<luci::CircleConst *>(fc->bias()))
+  if (auto const_bias = dynamic_cast<luci::CircleConst *>(nodes.fc->bias()))
   {
     RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32);
 
@@ -113,15 +116,15 @@ bool fuse_add_with_fc(luci::CircleAdd *add)
   // where N is weights->dim(0).
   // The shape is normalized to [N] to become the bias of FC
   fused_bias->rank(1);
-  fused_bias->dim(0) = weights->dim(0);
+  fused_bias->dim(0) = nodes.weights->dim(0);
 
-  fc->bias(fused_bias);
-  fc->fusedActivationFunction(add->fusedActivationFunction());
+  nodes.fc->bias(fused_bias);
+  nodes.fc->fusedActivationFunction(add->fusedActivationFunction());
 
   // set origin
-  luci::add_origin(fc, luci::get_origin(add));
+  luci::add_origin(nodes.fc, luci::get_origin(add));
 
-  replace(add).with(fc);
+  replace(add).with(nodes.fc);
 
   return true;
 }
@@ -144,22 +147,17 @@ luci::CircleQuantParam *get_qparam(luci::CircleNode *node, uint32_t len)
 
 bool fuse_add_with_s16_fc(luci::CircleAdd *add)
 {
-  luci::CircleFullyConnected *fc = nullptr;
-  // addition must be const
-  luci::CircleConst *addition = nullptr;
-  luci::CircleConst *weights = nullptr;
-  luci::CircleNode *bias = nullptr;
+  PatternNodes nodes;
 
-  RETURN_FALSE_UNLESS(
-    fc_with_add_pattern_check(loco::DataType::S16, &add, &fc, &addition, &weights, &bias));
+  RETURN_FALSE_UNLESS(fc_with_add_pattern_check(loco::DataType::S16, *add, nodes));
 
   // If bias is const, its dtype must be s64
-  RETURN_FALSE_UNLESS(bias->opcode() == luci::CircleOpcode::CIRCLECONST and
-                      bias->dtype() == loco::DataType::S64);
+  RETURN_FALSE_UNLESS(nodes.bias->opcode() == luci::CircleOpcode::CIRCLECONST and
+                      nodes.bias->dtype() == loco::DataType::S64);
 
-  const auto last_dim = addition->dim(addition->rank() - 1).value();
+  const auto last_dim = nodes.addition->dim(nodes.addition->rank() - 1).value();
 
-  const auto addition_qparam = get_qparam(addition, last_dim);
+  const auto addition_qparam = get_qparam(nodes.addition, last_dim);
   RETURN_FALSE_UNLESS(addition_qparam);
 
   std::vector<float> fp32_bias(last_dim);
@@ -168,12 +166,12 @@ bool fuse_add_with_s16_fc(luci::CircleAdd *add)
     auto scale = addition_qparam->scale.at(i);
     RETURN_FALSE_UNLESS(addition_qparam->zerop.at(i) == 0);
 
-    auto val = addition->at<loco::DataType::S16>(i);
+    auto val = nodes.addition->at<loco::DataType::S16>(i);
     fp32_bias[i] = val * scale;
   }
 
   // Add existing bias values
-  if (auto const_bias = dynamic_cast<luci::CircleConst *>(bias))
+  if (auto const_bias = dynamic_cast<luci::CircleConst *>(nodes.bias))
   {
     const auto bias_qparam = get_qparam(const_bias, last_dim);
     RETURN_FALSE_UNLESS(bias_qparam);
@@ -191,14 +189,14 @@ bool fuse_add_with_s16_fc(luci::CircleAdd *add)
   const auto add_qparam = get_qparam(add, 1);
   RETURN_FALSE_UNLESS(add_qparam);
 
-  auto input = loco::must_cast<luci::CircleNode *>(fc->input());
+  auto input = loco::must_cast<luci::CircleNode *>(nodes.fc->input());
   const auto input_qparam = get_qparam(input, 1);
   RETURN_FALSE_UNLESS(input_qparam);
 
-  const auto weights_qparam = get_qparam(weights, last_dim);
+  const auto weights_qparam = get_qparam(nodes.weights, last_dim);
   RETURN_FALSE_UNLESS(weights_qparam);
 
-  auto fused_bias = luci::clone(addition);
+  auto fused_bias = luci::clone(nodes.addition);
   fused_bias->dtype(loco::DataType::S64);
   fused_bias->size<loco::DataType::S64>(last_dim);
 
@@ -231,8 +229,8 @@ bool fuse_add_with_s16_fc(luci::CircleAdd *add)
   fused_bias->quantparam(std::move(bias_qparam));
 
   // In-place update. This works because fc is guaranteed to have a single successor
-  fc->bias(fused_bias);
-  fc->fusedActivationFunction(add->fusedActivationFunction());
+  nodes.fc->bias(fused_bias);
+  nodes.fc->fusedActivationFunction(add->fusedActivationFunction());
 
   auto qparam = std::make_unique<luci::CircleQuantParam>();
   {
@@ -240,12 +238,12 @@ bool fuse_add_with_s16_fc(luci::CircleAdd *add)
     qparam->zerop.push_back(add_qparam->scale.at(0));
   }
 
-  fc->quantparam(std::move(qparam));
+  nodes.fc->quantparam(std::move(qparam));
 
   // set origin
-  luci::add_origin(fc, luci::get_origin(add));
+  luci::add_origin(nodes.fc, luci::get_origin(add));
 
-  replace(add).with(fc);
+  replace(add).with(nodes.fc);
 
   return true;
 }
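For readers unfamiliar with `luci::fill(...).with_commutative_args_of(...)`: conceptually it tries to bind the `(FullyConnected, Const)` pair against both operand orders of the commutative `Add`. The sketch below is a simplified, hypothetical stand-in (the name, template, and `x()`/`y()` accessors are made up for illustration; this is not the luci helper itself) to show the idea.

```cpp
// Hypothetical illustration only -- not luci's implementation.
// Binds lhs/rhs against Add's two inputs in either order, since Add is commutative.
template <typename LHS, typename RHS, typename ADD>
bool fill_commutative_args(LHS *&lhs, RHS *&rhs, ADD *add)
{
  lhs = dynamic_cast<LHS *>(add->x());
  rhs = dynamic_cast<RHS *>(add->y());
  if (lhs != nullptr && rhs != nullptr)
    return true;

  // Try the swapped order: Add(Const, FullyConnected)
  lhs = dynamic_cast<LHS *>(add->y());
  rhs = dynamic_cast<RHS *>(add->x());
  return lhs != nullptr && rhs != nullptr;
}
```

This is also the point of the `PatternNodes` struct introduced in the diff: the matcher fills one struct instead of five out-parameters, and both the FLOAT32 and S16 callers read the matched `fc`, `addition`, `weights`, and `bias` from it.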
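The FLOAT32 path (`fuse_add_with_fc`) boils down to element-wise arithmetic: the Add's constant operand, shaped `[1, ..., 1, N]`, becomes the FC bias after accumulating any existing constant bias and being reshaped to rank-1 `[N]`. A minimal numeric sketch of that step, using a hypothetical helper name and plain `std::vector` in place of `CircleConst`:

```cpp
#include <cstddef>
#include <vector>

// Illustrative helper (hypothetical name): compute the bias of the fused FC
// in the FLOAT32 path. `addition` is the Add's constant operand flattened to
// its last dimension N; `old_bias` is empty when the FC had no bias
// (the CIRCLEOUTPUTEXCLUDE case).
std::vector<float> fused_fp32_bias(std::vector<float> addition, const std::vector<float> &old_bias)
{
  // Element-wise accumulate the existing bias into the cloned addition;
  // the result is the rank-1 [N] bias handed to the FullyConnected.
  for (std::size_t i = 0; i < old_bias.size() && i < addition.size(); ++i)
    addition[i] += old_bias[i];
  return addition;
}
```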
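The S16 path is trickier because the constant operand and the FC bias live on different quantization scales. The diff elides the actual requantization loop between the hunks above, so the sketch below shows only the general scheme such passes commonly follow: dequantize the S16 addition with its per-channel scale (zero-point 0), then requantize into the S64 bias assuming the usual convention `bias_scale = input_scale * weights_scale`. The function name and signature are hypothetical; this is not the pass's code.

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: turn the S16 "addition" constant into the S64 bias a
// quantized FullyConnected expects, assuming symmetric quantization
// (zero-point 0) and bias_scale = input_scale * per-channel weights_scale.
std::vector<int64_t> requantize_addition_to_bias(const std::vector<int16_t> &addition,
                                                 const std::vector<float> &addition_scale,
                                                 float input_scale,
                                                 const std::vector<float> &weights_scale)
{
  std::vector<int64_t> bias(addition.size());
  for (std::size_t i = 0; i < addition.size(); ++i)
  {
    // Dequantize: the real value added to output channel i
    const float fp32 = addition[i] * addition_scale[i];
    // Requantize onto the fused FC's bias scale
    bias[i] = static_cast<int64_t>(std::round(fp32 / (input_scale * weights_scale[i])));
  }
  return bias;
}
```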