Skip to content

Commit

Permalink
Use struct to handle pointers to nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
jiwaszki committed Sep 25, 2024
1 parent abf1acf commit d8e1245
Showing 1 changed file with 50 additions and 52 deletions.
102 changes: 50 additions & 52 deletions compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,40 +31,48 @@ namespace
if (not(cond)) \
return false;

bool fc_with_add_pattern_check(const loco::DataType dtype, luci::CircleAdd **add,
luci::CircleFullyConnected **fc, luci::CircleConst **addition,
luci::CircleConst **weights, luci::CircleNode **bias)
struct PatternNodes
{
RETURN_FALSE_UNLESS((*add)->dtype() == dtype);
luci::CircleFullyConnected *fc = nullptr;
// addition must be const
luci::CircleConst *addition = nullptr;
luci::CircleConst *weights = nullptr;
luci::CircleNode *bias = nullptr;
};

RETURN_FALSE_UNLESS(luci::fill(fc, addition).with_commutative_args_of(*add));
bool fc_with_add_pattern_check(const loco::DataType dtype, const luci::CircleAdd &add,
PatternNodes &nodes)
{
RETURN_FALSE_UNLESS(add.dtype() == dtype);

RETURN_FALSE_UNLESS(luci::fill(&nodes.fc, &nodes.addition).with_commutative_args_of(&add));

// Check if fc has only one successor to limit possible weights size increase.
RETURN_FALSE_UNLESS(loco::succs(*fc).size() == 1);
RETURN_FALSE_UNLESS((*fc)->dtype() == dtype);
RETURN_FALSE_UNLESS((*fc)->fusedActivationFunction() == luci::FusedActFunc::NONE);
RETURN_FALSE_UNLESS(loco::succs(nodes.fc).size() == 1);
RETURN_FALSE_UNLESS(nodes.fc->dtype() == dtype);
RETURN_FALSE_UNLESS(nodes.fc->fusedActivationFunction() == luci::FusedActFunc::NONE);

*weights = dynamic_cast<luci::CircleConst *>((*fc)->weights());
RETURN_FALSE_UNLESS(*weights);
nodes.weights = dynamic_cast<luci::CircleConst *>(nodes.fc->weights());
RETURN_FALSE_UNLESS(nodes.weights);

RETURN_FALSE_UNLESS((*addition)->dtype() == dtype);
RETURN_FALSE_UNLESS(nodes.addition->dtype() == dtype);

auto rank = (*addition)->rank();
auto rank = (nodes.addition)->rank();
// TODO Support scalar addition
RETURN_FALSE_UNLESS(rank != 0);

for (uint32_t i = 0; i < rank - 1; i++)
{
RETURN_FALSE_UNLESS((*addition)->dim(i).value() == 1);
RETURN_FALSE_UNLESS(nodes.addition->dim(i).value() == 1);
}
// Check the last dimesion of addition is the same with the number of neurons of FC
RETURN_FALSE_UNLESS((*addition)->dim(rank - 1) == (*weights)->dim(0));
RETURN_FALSE_UNLESS(nodes.addition->dim(rank - 1) == nodes.weights->dim(0));

// We only support (1) constant bias (2) no bias
// If bias is neither (1) nor (2), it would be a feature map
*bias = loco::must_cast<luci::CircleNode *>((*fc)->bias());
RETURN_FALSE_UNLESS((*bias)->opcode() == luci::CircleOpcode::CIRCLECONST or
(*bias)->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE);
nodes.bias = loco::must_cast<luci::CircleNode *>(nodes.fc->bias());
RETURN_FALSE_UNLESS(nodes.bias->opcode() == luci::CircleOpcode::CIRCLECONST or
nodes.bias->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE);

return true;
}
Expand All @@ -87,19 +95,14 @@ bool fc_with_add_pattern_check(const loco::DataType dtype, luci::CircleAdd **add
*/
bool fuse_add_with_fc(luci::CircleAdd *add)
{
luci::CircleFullyConnected *fc = nullptr;
// addition must be const
luci::CircleConst *addition = nullptr;
luci::CircleConst *weights = nullptr;
luci::CircleNode *bias = nullptr;
PatternNodes nodes;

RETURN_FALSE_UNLESS(
fc_with_add_pattern_check(loco::DataType::FLOAT32, &add, &fc, &addition, &weights, &bias));
RETURN_FALSE_UNLESS(fc_with_add_pattern_check(loco::DataType::FLOAT32, *add, nodes));

auto fused_bias = luci::clone(addition);
auto fused_bias = luci::clone(nodes.addition);

// Add existing bias values
if (auto const_bias = dynamic_cast<luci::CircleConst *>(fc->bias()))
if (auto const_bias = dynamic_cast<luci::CircleConst *>(nodes.fc->bias()))
{
RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32);

Expand All @@ -113,15 +116,15 @@ bool fuse_add_with_fc(luci::CircleAdd *add)
// where N is weights->dim(0).
// The shape is normalized to [N] to become the bias of FC
fused_bias->rank(1);
fused_bias->dim(0) = weights->dim(0);
fused_bias->dim(0) = nodes.weights->dim(0);

fc->bias(fused_bias);
fc->fusedActivationFunction(add->fusedActivationFunction());
nodes.fc->bias(fused_bias);
nodes.fc->fusedActivationFunction(add->fusedActivationFunction());

// set origin
luci::add_origin(fc, luci::get_origin(add));
luci::add_origin(nodes.fc, luci::get_origin(add));

replace(add).with(fc);
replace(add).with(nodes.fc);

return true;
}
Expand All @@ -144,22 +147,17 @@ luci::CircleQuantParam *get_qparam(luci::CircleNode *node, uint32_t len)

bool fuse_add_with_s16_fc(luci::CircleAdd *add)
{
luci::CircleFullyConnected *fc = nullptr;
// addition must be const
luci::CircleConst *addition = nullptr;
luci::CircleConst *weights = nullptr;
luci::CircleNode *bias = nullptr;
PatternNodes nodes;

RETURN_FALSE_UNLESS(
fc_with_add_pattern_check(loco::DataType::S16, &add, &fc, &addition, &weights, &bias));
RETURN_FALSE_UNLESS(fc_with_add_pattern_check(loco::DataType::S16, *add, nodes));

// If bias is const, its dtype must be s64
RETURN_FALSE_UNLESS(bias->opcode() == luci::CircleOpcode::CIRCLECONST and
bias->dtype() == loco::DataType::S64);
RETURN_FALSE_UNLESS(nodes.bias->opcode() == luci::CircleOpcode::CIRCLECONST and
nodes.bias->dtype() == loco::DataType::S64);

const auto last_dim = addition->dim(addition->rank() - 1).value();
const auto last_dim = nodes.addition->dim(nodes.addition->rank() - 1).value();

const auto addition_qparam = get_qparam(addition, last_dim);
const auto addition_qparam = get_qparam(nodes.addition, last_dim);
RETURN_FALSE_UNLESS(addition_qparam);

std::vector<float> fp32_bias(last_dim);
Expand All @@ -168,12 +166,12 @@ bool fuse_add_with_s16_fc(luci::CircleAdd *add)
auto scale = addition_qparam->scale.at(i);
RETURN_FALSE_UNLESS(addition_qparam->zerop.at(i) == 0);

auto val = addition->at<loco::DataType::S16>(i);
auto val = nodes.addition->at<loco::DataType::S16>(i);
fp32_bias[i] = val * scale;
}

// Add existing bias values
if (auto const_bias = dynamic_cast<luci::CircleConst *>(bias))
if (auto const_bias = dynamic_cast<luci::CircleConst *>(nodes.bias))
{
const auto bias_qparam = get_qparam(const_bias, last_dim);
RETURN_FALSE_UNLESS(bias_qparam);
Expand All @@ -191,14 +189,14 @@ bool fuse_add_with_s16_fc(luci::CircleAdd *add)
const auto add_qparam = get_qparam(add, 1);
RETURN_FALSE_UNLESS(add_qparam);

auto input = loco::must_cast<luci::CircleNode *>(fc->input());
auto input = loco::must_cast<luci::CircleNode *>(nodes.fc->input());
const auto input_qparam = get_qparam(input, 1);
RETURN_FALSE_UNLESS(input_qparam);

const auto weights_qparam = get_qparam(weights, last_dim);
const auto weights_qparam = get_qparam(nodes.weights, last_dim);
RETURN_FALSE_UNLESS(weights_qparam);

auto fused_bias = luci::clone(addition);
auto fused_bias = luci::clone(nodes.addition);
fused_bias->dtype(loco::DataType::S64);
fused_bias->size<loco::DataType::S64>(last_dim);

Expand Down Expand Up @@ -231,21 +229,21 @@ bool fuse_add_with_s16_fc(luci::CircleAdd *add)
fused_bias->quantparam(std::move(bias_qparam));

// In-place update. This works because fc is guaranteed to have a single successor
fc->bias(fused_bias);
fc->fusedActivationFunction(add->fusedActivationFunction());
nodes.fc->bias(fused_bias);
nodes.fc->fusedActivationFunction(add->fusedActivationFunction());

auto qparam = std::make_unique<luci::CircleQuantParam>();
{
qparam->scale.push_back(add_qparam->scale.at(0));
qparam->zerop.push_back(add_qparam->scale.at(0));
}

fc->quantparam(std::move(qparam));
nodes.fc->quantparam(std::move(qparam));

// set origin
luci::add_origin(fc, luci::get_origin(add));
luci::add_origin(nodes.fc, luci::get_origin(add));

replace(add).with(fc);
replace(add).with(nodes.fc);

return true;
}
Expand Down

0 comments on commit d8e1245

Please sign in to comment.