[luci/pass] Refactor FuseAddWithFullyConnectedPass #13846

Open · wants to merge 6 commits into master
Changes from 1 commit
102 changes: 50 additions & 52 deletions compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
@@ -31,40 +31,48 @@ namespace
if (not(cond)) \
return false;

-bool fc_with_add_pattern_check(const loco::DataType dtype, luci::CircleAdd **add,
-                               luci::CircleFullyConnected **fc, luci::CircleConst **addition,
-                               luci::CircleConst **weights, luci::CircleNode **bias)
+struct PatternNodes
{
-  RETURN_FALSE_UNLESS((*add)->dtype() == dtype);
+  luci::CircleFullyConnected *fc = nullptr;
+  // addition must be const
+  luci::CircleConst *addition = nullptr;
+  luci::CircleConst *weights = nullptr;
+  luci::CircleNode *bias = nullptr;
+};

-  RETURN_FALSE_UNLESS(luci::fill(fc, addition).with_commutative_args_of(*add));
+bool fc_with_add_pattern_check(const loco::DataType dtype, const luci::CircleAdd &add,
Review comment (Contributor):
Q) Why are you using const luci::CircleAdd & ?

Review comment (Contributor):
Instead of passing loco::DataType,

template <typename DT>
bool fc_with_add_pattern_check(const luci::CircleAdd &add,

?
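A minimal sketch of that templated variant (hypothetical, not code from this PR; a non-type template parameter is used here because loco::DataType is an enum):

template <loco::DataType DT>
bool fc_with_add_pattern_check(const luci::CircleAdd &add, PatternNodes &nodes)
{
  // Same checks as the runtime-dtype version, with DT fixed at compile time
  RETURN_FALSE_UNLESS(add.dtype() == DT);
  RETURN_FALSE_UNLESS(luci::fill(&nodes.fc, &nodes.addition).with_commutative_args_of(&add));
  // ... remaining pattern checks unchanged ...
  return true;
}

Call sites would then read fc_with_add_pattern_check<loco::DataType::FLOAT32>(*add, nodes) and fc_with_add_pattern_check<loco::DataType::S16>(*add, nodes).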

+                               PatternNodes &nodes)
{
+  RETURN_FALSE_UNLESS(add.dtype() == dtype);

+  RETURN_FALSE_UNLESS(luci::fill(&nodes.fc, &nodes.addition).with_commutative_args_of(&add));

// Check if fc has only one successor to limit possible weights size increase.
-  RETURN_FALSE_UNLESS(loco::succs(*fc).size() == 1);
-  RETURN_FALSE_UNLESS((*fc)->dtype() == dtype);
-  RETURN_FALSE_UNLESS((*fc)->fusedActivationFunction() == luci::FusedActFunc::NONE);
+  RETURN_FALSE_UNLESS(loco::succs(nodes.fc).size() == 1);
+  RETURN_FALSE_UNLESS(nodes.fc->dtype() == dtype);
+  RETURN_FALSE_UNLESS(nodes.fc->fusedActivationFunction() == luci::FusedActFunc::NONE);

-  *weights = dynamic_cast<luci::CircleConst *>((*fc)->weights());
-  RETURN_FALSE_UNLESS(*weights);
+  nodes.weights = dynamic_cast<luci::CircleConst *>(nodes.fc->weights());
+  RETURN_FALSE_UNLESS(nodes.weights);

-  RETURN_FALSE_UNLESS((*addition)->dtype() == dtype);
+  RETURN_FALSE_UNLESS(nodes.addition->dtype() == dtype);

-  auto rank = (*addition)->rank();
+  auto rank = (nodes.addition)->rank();
// TODO Support scalar addition
RETURN_FALSE_UNLESS(rank != 0);

for (uint32_t i = 0; i < rank - 1; i++)
{
-    RETURN_FALSE_UNLESS((*addition)->dim(i).value() == 1);
+    RETURN_FALSE_UNLESS(nodes.addition->dim(i).value() == 1);
}
  // Check the last dimension of addition is the same as the number of neurons of FC
-  RETURN_FALSE_UNLESS((*addition)->dim(rank - 1) == (*weights)->dim(0));
+  RETURN_FALSE_UNLESS(nodes.addition->dim(rank - 1) == nodes.weights->dim(0));

// We only support (1) constant bias (2) no bias
// If bias is neither (1) nor (2), it would be a feature map
-  *bias = loco::must_cast<luci::CircleNode *>((*fc)->bias());
-  RETURN_FALSE_UNLESS((*bias)->opcode() == luci::CircleOpcode::CIRCLECONST or
-                      (*bias)->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE);
+  nodes.bias = loco::must_cast<luci::CircleNode *>(nodes.fc->bias());
+  RETURN_FALSE_UNLESS(nodes.bias->opcode() == luci::CircleOpcode::CIRCLECONST or
+                      nodes.bias->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE);

return true;
}
@@ -87,19 +95,14 @@ bool fc_with_add_pattern_check(const loco::DataType dtype, luci::CircleAdd **add
*/
bool fuse_add_with_fc(luci::CircleAdd *add)
{
-  luci::CircleFullyConnected *fc = nullptr;
-  // addition must be const
-  luci::CircleConst *addition = nullptr;
-  luci::CircleConst *weights = nullptr;
-  luci::CircleNode *bias = nullptr;
+  PatternNodes nodes;
Review comment (Contributor):
It seems you want to find the members of this nodes struct; there are similar styles in other passes that you can refer to. Please check compiler/luci/pass/src/FuseRmsNormPass.cpp, which was recently added.
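For reference, a rough sketch of that pattern-holder style (class and member names here are illustrative, not the actual FuseRmsNormPass API; the checks mirror this PR's own):

class FCAddPattern
{
public:
  FCAddPattern(luci::CircleAdd *add) : _add(add) {}

  // Fill the member nodes if the graph around _add matches FC -> Add;
  // return false as soon as any condition fails.
  bool matched()
  {
    RETURN_FALSE_UNLESS(luci::fill(&_fc, &_addition).with_commutative_args_of(_add));
    _weights = dynamic_cast<luci::CircleConst *>(_fc->weights());
    RETURN_FALSE_UNLESS(_weights);
    _bias = loco::must_cast<luci::CircleNode *>(_fc->bias());
    return true;
  }

public:
  luci::CircleAdd *_add = nullptr;
  luci::CircleFullyConnected *_fc = nullptr;
  luci::CircleConst *_addition = nullptr;
  luci::CircleConst *_weights = nullptr;
  luci::CircleNode *_bias = nullptr;
};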


-  RETURN_FALSE_UNLESS(
-      fc_with_add_pattern_check(loco::DataType::FLOAT32, &add, &fc, &addition, &weights, &bias));
+  RETURN_FALSE_UNLESS(fc_with_add_pattern_check(loco::DataType::FLOAT32, *add, nodes));

-  auto fused_bias = luci::clone(addition);
+  auto fused_bias = luci::clone(nodes.addition);

// Add existing bias values
-  if (auto const_bias = dynamic_cast<luci::CircleConst *>(fc->bias()))
+  if (auto const_bias = dynamic_cast<luci::CircleConst *>(nodes.fc->bias()))
{
RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32);
Review comment (Contributor):
Could you check this inside common_pass_checks? luci::clone adds a node into the graph (the graph is modified), so all checks have to be done before luci::clone.

Review comment (Contributor):
Ah, this check is specific to float32. It would be better to use a pattern class (PTAL at how FuseGeluPass finds different patterns).

Review comment (Contributor Author, @jiwaszki, Sep 23, 2024):
@jinevening could this refactor be done in a separate PR? In my opinion, it would be better to push it as-is and create another PR that introduces pattern search like FuseGeluPass.
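To illustrate the ordering being asked for, a hedged sketch (assuming the float32-specific bias check simply moves above the clone; helper names as in this PR):

bool fuse_add_with_fc(luci::CircleAdd *add)
{
  PatternNodes nodes;
  RETURN_FALSE_UNLESS(fc_with_add_pattern_check(loco::DataType::FLOAT32, *add, nodes));

  // float32-specific bias check, done BEFORE luci::clone so the graph
  // is still unmodified if we bail out
  if (auto const_bias = dynamic_cast<luci::CircleConst *>(nodes.fc->bias()))
    RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32);

  auto fused_bias = luci::clone(nodes.addition); // graph is modified from here on
  // ... fuse and replace as in the PR ...
  return true;
}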


@@ -113,15 +116,15 @@ bool fuse_add_with_fc(luci::CircleAdd *add)
// where N is weights->dim(0).
// The shape is normalized to [N] to become the bias of FC
fused_bias->rank(1);
-  fused_bias->dim(0) = weights->dim(0);
+  fused_bias->dim(0) = nodes.weights->dim(0);

-  fc->bias(fused_bias);
-  fc->fusedActivationFunction(add->fusedActivationFunction());
+  nodes.fc->bias(fused_bias);
+  nodes.fc->fusedActivationFunction(add->fusedActivationFunction());

// set origin
-  luci::add_origin(fc, luci::get_origin(add));
+  luci::add_origin(nodes.fc, luci::get_origin(add));

-  replace(add).with(fc);
+  replace(add).with(nodes.fc);

return true;
}
@@ -144,22 +147,17 @@ luci::CircleQuantParam *get_qparam(luci::CircleNode *node, uint32_t len)

bool fuse_add_with_s16_fc(luci::CircleAdd *add)
{
-  luci::CircleFullyConnected *fc = nullptr;
-  // addition must be const
-  luci::CircleConst *addition = nullptr;
-  luci::CircleConst *weights = nullptr;
-  luci::CircleNode *bias = nullptr;
+  PatternNodes nodes;

-  RETURN_FALSE_UNLESS(
-      fc_with_add_pattern_check(loco::DataType::S16, &add, &fc, &addition, &weights, &bias));
+  RETURN_FALSE_UNLESS(fc_with_add_pattern_check(loco::DataType::S16, *add, nodes));

// If bias is const, its dtype must be s64
-  RETURN_FALSE_UNLESS(bias->opcode() == luci::CircleOpcode::CIRCLECONST and
-                      bias->dtype() == loco::DataType::S64);
+  RETURN_FALSE_UNLESS(nodes.bias->opcode() == luci::CircleOpcode::CIRCLECONST and
+                      nodes.bias->dtype() == loco::DataType::S64);

-  const auto last_dim = addition->dim(addition->rank() - 1).value();
+  const auto last_dim = nodes.addition->dim(nodes.addition->rank() - 1).value();

-  const auto addition_qparam = get_qparam(addition, last_dim);
+  const auto addition_qparam = get_qparam(nodes.addition, last_dim);
RETURN_FALSE_UNLESS(addition_qparam);

std::vector<float> fp32_bias(last_dim);
@@ -168,12 +166,12 @@ bool fuse_add_with_s16_fc(luci::CircleAdd *add)
auto scale = addition_qparam->scale.at(i);
RETURN_FALSE_UNLESS(addition_qparam->zerop.at(i) == 0);

-    auto val = addition->at<loco::DataType::S16>(i);
+    auto val = nodes.addition->at<loco::DataType::S16>(i);
fp32_bias[i] = val * scale;
}

// Add existing bias values
-  if (auto const_bias = dynamic_cast<luci::CircleConst *>(bias))
+  if (auto const_bias = dynamic_cast<luci::CircleConst *>(nodes.bias))
{
const auto bias_qparam = get_qparam(const_bias, last_dim);
RETURN_FALSE_UNLESS(bias_qparam);
@@ -191,14 +189,14 @@ bool fuse_add_with_s16_fc(luci::CircleAdd *add)
const auto add_qparam = get_qparam(add, 1);
RETURN_FALSE_UNLESS(add_qparam);

-  auto input = loco::must_cast<luci::CircleNode *>(fc->input());
+  auto input = loco::must_cast<luci::CircleNode *>(nodes.fc->input());
const auto input_qparam = get_qparam(input, 1);
RETURN_FALSE_UNLESS(input_qparam);

-  const auto weights_qparam = get_qparam(weights, last_dim);
+  const auto weights_qparam = get_qparam(nodes.weights, last_dim);
RETURN_FALSE_UNLESS(weights_qparam);

-  auto fused_bias = luci::clone(addition);
+  auto fused_bias = luci::clone(nodes.addition);
fused_bias->dtype(loco::DataType::S64);
fused_bias->size<loco::DataType::S64>(last_dim);

@@ -231,21 +229,21 @@ bool fuse_add_with_s16_fc(luci::CircleAdd *add)
fused_bias->quantparam(std::move(bias_qparam));

// In-place update. This works because fc is guaranteed to have a single successor
-  fc->bias(fused_bias);
-  fc->fusedActivationFunction(add->fusedActivationFunction());
+  nodes.fc->bias(fused_bias);
+  nodes.fc->fusedActivationFunction(add->fusedActivationFunction());

auto qparam = std::make_unique<luci::CircleQuantParam>();
{
qparam->scale.push_back(add_qparam->scale.at(0));
    qparam->zerop.push_back(add_qparam->zerop.at(0));
}

-  fc->quantparam(std::move(qparam));
+  nodes.fc->quantparam(std::move(qparam));

// set origin
-  luci::add_origin(fc, luci::get_origin(add));
+  luci::add_origin(nodes.fc, luci::get_origin(add));

-  replace(add).with(fc);
+  replace(add).with(nodes.fc);

return true;
}