Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[luci/pass] Refactor FuseAddWithFullyConnectedPass #13846

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 79 additions & 133 deletions compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include "luci/Pass/FuseAddWithFullyConnectedPass.h"

#include "helpers/NodeFiller.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Profile/CircleNodeOrigin.h>
Expand All @@ -24,6 +26,49 @@

namespace
{

// Early-exit helper: bail out of the enclosing bool-returning function when
// `cond` does not hold.
//
// Wrapped in do { } while (false) so the macro behaves as a single statement:
// the previous bare-`if` form broke (or silently mis-bound) when used inside
// an if/else, e.g. `if (x) RETURN_FALSE_UNLESS(y); else ...`.
#define RETURN_FALSE_UNLESS(cond) \
  do                              \
  {                               \
    if (not(cond))                \
      return false;               \
  } while (false)

/**
 * @brief Precondition checks shared by the FLOAT32 and S16 Add+FC fusion paths.
 *
 * Starting from `add`, matches the pattern `FullyConnected -> Add(constant)` and
 * fills the out-parameters with the discovered nodes on success.
 *
 * @param dtype    Required dtype for `add`, the FC, and the added constant.
 * @param add      [in] Candidate Add node (already non-null; only dereferenced).
 * @param fc       [out] Matched FullyConnected producer of `add`.
 * @param addition [out] Matched constant operand of `add`.
 * @param weights  [out] FC weights, which must be constant.
 * @param bias     [out] FC bias node (constant or "no bias" placeholder).
 * @return true iff the pattern matches and all fusion preconditions hold.
 *         On false, out-parameters may be partially filled and must not be used.
 */
inline bool common_pass_checks(const loco::DataType dtype, luci::CircleAdd **add,
                               luci::CircleFullyConnected **fc, luci::CircleConst **addition,
                               luci::CircleConst **weights, luci::CircleNode **bias)
{
  RETURN_FALSE_UNLESS((*add)->dtype() == dtype);

  // Bind (*fc, *addition) to the operands of `add` in either order
  // (Add is commutative); fails unless one operand is an FC and the other
  // a constant. Helper comes from helpers/NodeFiller.h.
  RETURN_FALSE_UNLESS(luci::fill(fc, addition).with_commutative_args_of(*add));

  // Check if fc has only one successor to limit possible weights size increase.
  RETURN_FALSE_UNLESS(loco::succs(*fc).size() == 1);
  RETURN_FALSE_UNLESS((*fc)->dtype() == dtype);
  // Fusion folds `addition` into the bias, which is only valid when no
  // activation sits between the FC output and the Add.
  RETURN_FALSE_UNLESS((*fc)->fusedActivationFunction() == luci::FusedActFunc::NONE);

  // Weights must be constant so the fused bias can be computed at compile time.
  *weights = dynamic_cast<luci::CircleConst *>((*fc)->weights());
  RETURN_FALSE_UNLESS(*weights);

  RETURN_FALSE_UNLESS((*addition)->dtype() == dtype);

  auto rank = (*addition)->rank();
  // TODO Support scalar addition
  RETURN_FALSE_UNLESS(rank != 0);

  // The added constant must be channel-wise: all leading dims are 1, so it
  // broadcasts over everything except the last (channel) dimension.
  for (uint32_t i = 0; i < rank - 1; i++)
  {
    RETURN_FALSE_UNLESS((*addition)->dim(i).value() == 1);
  }
  // Check that the last dimension of addition equals the number of neurons of
  // the FC (weights dim 0 — presumably [out_channels, in_channels] layout).
  RETURN_FALSE_UNLESS((*addition)->dim(rank - 1) == (*weights)->dim(0));

  // We only support (1) constant bias (2) no bias
  // If bias is neither (1) nor (2), it would be a feature map
  *bias = loco::must_cast<luci::CircleNode *>((*fc)->bias());
  RETURN_FALSE_UNLESS((*bias)->opcode() == luci::CircleOpcode::CIRCLECONST or
                      (*bias)->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE);

  return true;
}

/**
* Fuse Add to FullyConnected if the added value is a channel(last dimension)-wise constant
*
Expand All @@ -40,71 +85,26 @@ namespace
* |
*
*/
bool fuse_add_with_fc(luci::CircleFullyConnected *fc)
bool fuse_add_with_fc(luci::CircleAdd *add)
{
if (not fc)
return false;
luci::CircleFullyConnected *fc = nullptr;
// addition must be const
luci::CircleConst *addition = nullptr;
luci::CircleConst *weights = nullptr;
luci::CircleNode *bias = nullptr;

if (fc->dtype() != loco::DataType::FLOAT32)
return false;

if (fc->fusedActivationFunction() != luci::FusedActFunc::NONE)
return false;

auto weights = dynamic_cast<luci::CircleConst *>(fc->weights());
if (not weights)
return false;

// Get add node
auto fc_output = loco::succs(fc);
if (fc_output.size() != 1)
return false;

auto add = dynamic_cast<luci::CircleAdd *>(*fc_output.begin());
if (not add)
return false;
if (add->dtype() != loco::DataType::FLOAT32)
return false;

// Get addition
auto addition = add->x() == fc ? dynamic_cast<luci::CircleConst *>(add->y())
: dynamic_cast<luci::CircleConst *>(add->x());

// Non-const addition
if (not addition)
return false;

auto rank = addition->rank();
// TODO Support scalar addition
if (rank == 0)
return false;

for (uint32_t i = 0; i < rank - 1; i++)
{
if (addition->dim(i).value() != 1)
return false;
}
// Check that the last dimension of addition is the same as the number of neurons of FC
if (not(addition->dim(rank - 1) == weights->dim(0)))
return false;

auto bias = loco::must_cast<luci::CircleNode *>(fc->bias());

// We only support (1) constant bias (2) no bias
// If bias is neither (1) nor (2), it would be a feature map
if (bias->opcode() != luci::CircleOpcode::CIRCLECONST and
bias->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
return false;
RETURN_FALSE_UNLESS(
common_pass_checks(loco::DataType::FLOAT32, &add, &fc, &addition, &weights, &bias));

auto fused_bias = luci::clone(addition);

// Add existing bias values
if (auto const_bias = dynamic_cast<luci::CircleConst *>(fc->bias()))
{
assert(const_bias->dtype() == loco::DataType::FLOAT32);
RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you check this inside common_pass_checks? luci::clone adds a node into the graph (graph is modified), so all checks had to be done before luci::clone.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, this check is specific to float32.. It would be better to use a pattern class (PTAL how FuseGeluPass finds different patterns).

Copy link
Contributor Author

@jiwaszki jiwaszki Sep 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jinevening could this refactor be done in separate PR? In my opinion, it would be better to push it as-is and create another PR that will introduce pattern search like FuseGeluPass.


auto bias_size = fused_bias->size<loco::DataType::FLOAT32>();
assert(bias_size == const_bias->size<loco::DataType::FLOAT32>());
RETURN_FALSE_UNLESS(bias_size == const_bias->size<loco::DataType::FLOAT32>());
for (uint32_t i = 0; i < bias_size; i++)
fused_bias->at<loco::DataType::FLOAT32>(i) += const_bias->at<loco::DataType::FLOAT32>(i);
}
Expand Down Expand Up @@ -142,80 +142,31 @@ luci::CircleQuantParam *get_qparam(luci::CircleNode *node, uint32_t len)
return node->quantparam();
}

bool fuse_add_with_s16_fc(luci::CircleFullyConnected *fc)
bool fuse_add_with_s16_fc(luci::CircleAdd *add)
{
assert(fc); // FIX_CALLER_UNLESS
assert(fc->dtype() == loco::DataType::S16); // FIX_CALLER_UNLESS

if (fc->fusedActivationFunction() != luci::FusedActFunc::NONE)
return false;
luci::CircleFullyConnected *fc = nullptr;
// addition must be const
luci::CircleConst *addition = nullptr;
luci::CircleConst *weights = nullptr;
luci::CircleNode *bias = nullptr;

auto weights = dynamic_cast<luci::CircleConst *>(fc->weights());
if (not weights)
return false;

auto fc_output = loco::succs(fc);
// Fuse only when FC has a single successor (to avoid weight increase)
if (fc_output.size() != 1)
return false;

auto add = dynamic_cast<luci::CircleAdd *>(*fc_output.begin());
if (not add)
return false;

// Only support the same dtype with fc
if (add->dtype() != loco::DataType::S16)
return false;

// Get addition
auto addition = add->x() == fc ? dynamic_cast<luci::CircleConst *>(add->y())
: dynamic_cast<luci::CircleConst *>(add->x());

// Non-const addition
if (not addition)
return false;

// Check addition dtype
if (addition->dtype() != loco::DataType::S16)
return false;

auto rank = addition->rank();
// TODO Support scalar addition
if (rank == 0)
return false;

for (uint32_t i = 0; i < rank - 1; i++)
{
if (addition->dim(i).value() != 1)
return false;
}

// Check the last dim of addition is the same with the output dim of weight
const auto last_dim = addition->dim(rank - 1).value();
if (last_dim != weights->dim(0).value())
return false;

auto bias = loco::must_cast<luci::CircleNode *>(fc->bias());

// Only support (1) constant bias, or (2) no bias
if (bias->opcode() != luci::CircleOpcode::CIRCLECONST and
bias->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
return false;
RETURN_FALSE_UNLESS(
common_pass_checks(loco::DataType::S16, &add, &fc, &addition, &weights, &bias));

// If bias is const, its dtype must be s64
if (bias->opcode() == luci::CircleOpcode::CIRCLECONST and bias->dtype() != loco::DataType::S64)
return false;
RETURN_FALSE_UNLESS(bias->opcode() == luci::CircleOpcode::CIRCLECONST and
bias->dtype() == loco::DataType::S64);

const auto last_dim = addition->dim(addition->rank() - 1).value();

const auto addition_qparam = get_qparam(addition, last_dim);
if (addition_qparam == nullptr)
return false;
RETURN_FALSE_UNLESS(addition_qparam);

std::vector<float> fp32_bias(last_dim);
for (uint32_t i = 0; i < last_dim; i++)
{
auto scale = addition_qparam->scale.at(i);
if (addition_qparam->zerop.at(i) != 0)
return false; // FIX_ME_UNLESS
RETURN_FALSE_UNLESS(addition_qparam->zerop.at(i) == 0);

auto val = addition->at<loco::DataType::S16>(i);
fp32_bias[i] = val * scale;
Expand All @@ -225,32 +176,27 @@ bool fuse_add_with_s16_fc(luci::CircleFullyConnected *fc)
if (auto const_bias = dynamic_cast<luci::CircleConst *>(bias))
{
const auto bias_qparam = get_qparam(const_bias, last_dim);
if (bias_qparam == nullptr)
return false;
RETURN_FALSE_UNLESS(bias_qparam);

for (uint32_t i = 0; i < last_dim; i++)
{
auto scale = bias_qparam->scale.at(i);
if (bias_qparam->zerop.at(i) != 0)
return false; // FIX_ME_UNLESS
RETURN_FALSE_UNLESS(bias_qparam->zerop.at(i) == 0);

auto val = const_bias->at<loco::DataType::S64>(i);
fp32_bias[i] += val * scale;
}
}

const auto add_qparam = get_qparam(add, 1);
if (add_qparam == nullptr)
return false;
RETURN_FALSE_UNLESS(add_qparam);

auto input = loco::must_cast<luci::CircleNode *>(fc->input());
const auto input_qparam = get_qparam(input, 1);
if (input_qparam == nullptr)
return false;
RETURN_FALSE_UNLESS(input_qparam);

const auto weights_qparam = get_qparam(weights, last_dim);
if (weights_qparam == nullptr)
return false;
RETURN_FALSE_UNLESS(weights_qparam);

auto fused_bias = luci::clone(addition);
fused_bias->dtype(loco::DataType::S64);
Expand Down Expand Up @@ -314,18 +260,18 @@ bool FuseAddWithFullyConnectedPass::run(loco::Graph *g)
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
if (not fc)
auto add = dynamic_cast<luci::CircleAdd *>(node);
if (not add)
continue;

switch (fc->dtype())
switch (add->dtype())
{
case loco::DataType::FLOAT32:
if (fuse_add_with_fc(fc))
if (fuse_add_with_fc(add))
changed = true;
break;
case loco::DataType::S16:
if (fuse_add_with_s16_fc(fc))
if (fuse_add_with_s16_fc(add))
changed = true;
break;
default:
Expand Down