-
Notifications
You must be signed in to change notification settings - Fork 158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[luci/pass] Refactor FuseAddWithFullyConnectedPass #13846
base: master
Are you sure you want to change the base?
Changes from 1 commit
f1c258d
6c13ac5
7076ebd
d823de1
abf1acf
d8e1245
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,40 +31,48 @@ namespace | |
if (not(cond)) \ | ||
return false; | ||
|
||
bool fc_with_add_pattern_check(const loco::DataType dtype, luci::CircleAdd **add, | ||
luci::CircleFullyConnected **fc, luci::CircleConst **addition, | ||
luci::CircleConst **weights, luci::CircleNode **bias) | ||
struct PatternNodes | ||
{ | ||
RETURN_FALSE_UNLESS((*add)->dtype() == dtype); | ||
luci::CircleFullyConnected *fc = nullptr; | ||
// addition must be const | ||
luci::CircleConst *addition = nullptr; | ||
luci::CircleConst *weights = nullptr; | ||
luci::CircleNode *bias = nullptr; | ||
}; | ||
|
||
RETURN_FALSE_UNLESS(luci::fill(fc, addition).with_commutative_args_of(*add)); | ||
bool fc_with_add_pattern_check(const loco::DataType dtype, const luci::CircleAdd &add, | ||
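There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Instead of passing `dtype` as a function argument (inline code dropped by extraction; inferred from the signature below), how about making it a template parameter?
`template <typename DT> bool fc_with_add_pattern_check(const luci::CircleAdd &add, ...)` |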
||
PatternNodes &nodes) | ||
{ | ||
RETURN_FALSE_UNLESS(add.dtype() == dtype); | ||
|
||
RETURN_FALSE_UNLESS(luci::fill(&nodes.fc, &nodes.addition).with_commutative_args_of(&add)); | ||
|
||
// Check if fc has only one successor to limit possible weights size increase. | ||
RETURN_FALSE_UNLESS(loco::succs(*fc).size() == 1); | ||
RETURN_FALSE_UNLESS((*fc)->dtype() == dtype); | ||
RETURN_FALSE_UNLESS((*fc)->fusedActivationFunction() == luci::FusedActFunc::NONE); | ||
RETURN_FALSE_UNLESS(loco::succs(nodes.fc).size() == 1); | ||
RETURN_FALSE_UNLESS(nodes.fc->dtype() == dtype); | ||
RETURN_FALSE_UNLESS(nodes.fc->fusedActivationFunction() == luci::FusedActFunc::NONE); | ||
|
||
*weights = dynamic_cast<luci::CircleConst *>((*fc)->weights()); | ||
RETURN_FALSE_UNLESS(*weights); | ||
nodes.weights = dynamic_cast<luci::CircleConst *>(nodes.fc->weights()); | ||
RETURN_FALSE_UNLESS(nodes.weights); | ||
|
||
RETURN_FALSE_UNLESS((*addition)->dtype() == dtype); | ||
RETURN_FALSE_UNLESS(nodes.addition->dtype() == dtype); | ||
|
||
auto rank = (*addition)->rank(); | ||
auto rank = (nodes.addition)->rank(); | ||
// TODO Support scalar addition | ||
RETURN_FALSE_UNLESS(rank != 0); | ||
|
||
for (uint32_t i = 0; i < rank - 1; i++) | ||
{ | ||
RETURN_FALSE_UNLESS((*addition)->dim(i).value() == 1); | ||
RETURN_FALSE_UNLESS(nodes.addition->dim(i).value() == 1); | ||
} | ||
// Check the last dimension of addition is the same as the number of neurons of FC | ||
RETURN_FALSE_UNLESS((*addition)->dim(rank - 1) == (*weights)->dim(0)); | ||
RETURN_FALSE_UNLESS(nodes.addition->dim(rank - 1) == nodes.weights->dim(0)); | ||
|
||
// We only support (1) constant bias (2) no bias | ||
// If bias is neither (1) nor (2), it would be a feature map | ||
*bias = loco::must_cast<luci::CircleNode *>((*fc)->bias()); | ||
RETURN_FALSE_UNLESS((*bias)->opcode() == luci::CircleOpcode::CIRCLECONST or | ||
(*bias)->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE); | ||
nodes.bias = loco::must_cast<luci::CircleNode *>(nodes.fc->bias()); | ||
RETURN_FALSE_UNLESS(nodes.bias->opcode() == luci::CircleOpcode::CIRCLECONST or | ||
nodes.bias->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE); | ||
|
||
return true; | ||
} | ||
|
@@ -87,19 +95,14 @@ bool fc_with_add_pattern_check(const loco::DataType dtype, luci::CircleAdd **add | |
*/ | ||
bool fuse_add_with_fc(luci::CircleAdd *add) | ||
{ | ||
luci::CircleFullyConnected *fc = nullptr; | ||
// addition must be const | ||
luci::CircleConst *addition = nullptr; | ||
luci::CircleConst *weights = nullptr; | ||
luci::CircleNode *bias = nullptr; | ||
PatternNodes nodes; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it seems you want to find members of this |
||
|
||
RETURN_FALSE_UNLESS( | ||
fc_with_add_pattern_check(loco::DataType::FLOAT32, &add, &fc, &addition, &weights, &bias)); | ||
RETURN_FALSE_UNLESS(fc_with_add_pattern_check(loco::DataType::FLOAT32, *add, nodes)); | ||
|
||
auto fused_bias = luci::clone(addition); | ||
auto fused_bias = luci::clone(nodes.addition); | ||
|
||
// Add existing bias values | ||
if (auto const_bias = dynamic_cast<luci::CircleConst *>(fc->bias())) | ||
if (auto const_bias = dynamic_cast<luci::CircleConst *>(nodes.fc->bias())) | ||
{ | ||
RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you check this inside There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, this check is specific to float32.. It would be better to use a pattern class (PTAL how There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jinevening could this refactor be done in separate PR? In my opinion, it would be better to push it as-is and create another PR that will introduce pattern search like |
||
|
||
|
@@ -113,15 +116,15 @@ bool fuse_add_with_fc(luci::CircleAdd *add) | |
// where N is weights->dim(0). | ||
// The shape is normalized to [N] to become the bias of FC | ||
fused_bias->rank(1); | ||
fused_bias->dim(0) = weights->dim(0); | ||
fused_bias->dim(0) = nodes.weights->dim(0); | ||
|
||
fc->bias(fused_bias); | ||
fc->fusedActivationFunction(add->fusedActivationFunction()); | ||
nodes.fc->bias(fused_bias); | ||
nodes.fc->fusedActivationFunction(add->fusedActivationFunction()); | ||
|
||
// set origin | ||
luci::add_origin(fc, luci::get_origin(add)); | ||
luci::add_origin(nodes.fc, luci::get_origin(add)); | ||
|
||
replace(add).with(fc); | ||
replace(add).with(nodes.fc); | ||
|
||
return true; | ||
} | ||
|
@@ -144,22 +147,17 @@ luci::CircleQuantParam *get_qparam(luci::CircleNode *node, uint32_t len) | |
|
||
bool fuse_add_with_s16_fc(luci::CircleAdd *add) | ||
{ | ||
luci::CircleFullyConnected *fc = nullptr; | ||
// addition must be const | ||
luci::CircleConst *addition = nullptr; | ||
luci::CircleConst *weights = nullptr; | ||
luci::CircleNode *bias = nullptr; | ||
PatternNodes nodes; | ||
|
||
RETURN_FALSE_UNLESS( | ||
fc_with_add_pattern_check(loco::DataType::S16, &add, &fc, &addition, &weights, &bias)); | ||
RETURN_FALSE_UNLESS(fc_with_add_pattern_check(loco::DataType::S16, *add, nodes)); | ||
|
||
// If bias is const, its dtype must be s64 | ||
RETURN_FALSE_UNLESS(bias->opcode() == luci::CircleOpcode::CIRCLECONST and | ||
bias->dtype() == loco::DataType::S64); | ||
RETURN_FALSE_UNLESS(nodes.bias->opcode() == luci::CircleOpcode::CIRCLECONST and | ||
nodes.bias->dtype() == loco::DataType::S64); | ||
|
||
const auto last_dim = addition->dim(addition->rank() - 1).value(); | ||
const auto last_dim = nodes.addition->dim(nodes.addition->rank() - 1).value(); | ||
|
||
const auto addition_qparam = get_qparam(addition, last_dim); | ||
const auto addition_qparam = get_qparam(nodes.addition, last_dim); | ||
RETURN_FALSE_UNLESS(addition_qparam); | ||
|
||
std::vector<float> fp32_bias(last_dim); | ||
|
@@ -168,12 +166,12 @@ bool fuse_add_with_s16_fc(luci::CircleAdd *add) | |
auto scale = addition_qparam->scale.at(i); | ||
RETURN_FALSE_UNLESS(addition_qparam->zerop.at(i) == 0); | ||
|
||
auto val = addition->at<loco::DataType::S16>(i); | ||
auto val = nodes.addition->at<loco::DataType::S16>(i); | ||
fp32_bias[i] = val * scale; | ||
} | ||
|
||
// Add existing bias values | ||
if (auto const_bias = dynamic_cast<luci::CircleConst *>(bias)) | ||
if (auto const_bias = dynamic_cast<luci::CircleConst *>(nodes.bias)) | ||
{ | ||
const auto bias_qparam = get_qparam(const_bias, last_dim); | ||
RETURN_FALSE_UNLESS(bias_qparam); | ||
|
@@ -191,14 +189,14 @@ bool fuse_add_with_s16_fc(luci::CircleAdd *add) | |
const auto add_qparam = get_qparam(add, 1); | ||
RETURN_FALSE_UNLESS(add_qparam); | ||
|
||
auto input = loco::must_cast<luci::CircleNode *>(fc->input()); | ||
auto input = loco::must_cast<luci::CircleNode *>(nodes.fc->input()); | ||
const auto input_qparam = get_qparam(input, 1); | ||
RETURN_FALSE_UNLESS(input_qparam); | ||
|
||
const auto weights_qparam = get_qparam(weights, last_dim); | ||
const auto weights_qparam = get_qparam(nodes.weights, last_dim); | ||
RETURN_FALSE_UNLESS(weights_qparam); | ||
|
||
auto fused_bias = luci::clone(addition); | ||
auto fused_bias = luci::clone(nodes.addition); | ||
fused_bias->dtype(loco::DataType::S64); | ||
fused_bias->size<loco::DataType::S64>(last_dim); | ||
|
||
|
@@ -231,21 +229,21 @@ bool fuse_add_with_s16_fc(luci::CircleAdd *add) | |
fused_bias->quantparam(std::move(bias_qparam)); | ||
|
||
// In-place update. This works because fc is guaranteed to have a single successor | ||
fc->bias(fused_bias); | ||
fc->fusedActivationFunction(add->fusedActivationFunction()); | ||
nodes.fc->bias(fused_bias); | ||
nodes.fc->fusedActivationFunction(add->fusedActivationFunction()); | ||
|
||
auto qparam = std::make_unique<luci::CircleQuantParam>(); | ||
{ | ||
qparam->scale.push_back(add_qparam->scale.at(0)); | ||
qparam->zerop.push_back(add_qparam->scale.at(0)); | ||
} | ||
|
||
fc->quantparam(std::move(qparam)); | ||
nodes.fc->quantparam(std::move(qparam)); | ||
|
||
// set origin | ||
luci::add_origin(fc, luci::get_origin(add)); | ||
luci::add_origin(nodes.fc, luci::get_origin(add)); | ||
|
||
replace(add).with(fc); | ||
replace(add).with(nodes.fc); | ||
|
||
return true; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Q) Why are you using `const luci::CircleAdd &`?