[luci/pass] Refactor FuseAddWithTConvPass #13745

Merged
93 changes: 41 additions & 52 deletions compiler/luci/pass/src/FuseAddWithTConvPass.cpp
@@ -16,11 +16,18 @@

#include "luci/Pass/FuseAddWithTConvPass.h"

#include "helpers/NodeFiller.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>

namespace
{

#define RETURN_FALSE_UNLESS(cond) \
if (not(cond)) \
return false;

/**
* Fuse Add to TransposeConv if possible
*
@@ -42,89 +49,74 @@ namespace
*
* Note: CircleRelu/Relu6 is inserted if Add activation is ReLU6
*/
bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
bool fuse_add_with_tconv(luci::CircleAdd *add)
{
// skip if tconv has fused activation
if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
return false;
// check whether it has bias or not. This optimization works only if it doesn't.
auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
if (not bias)
return false;

// get weight of tconv
auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
if (not filter)
return false;
if (filter->dtype() != loco::DataType::FLOAT32)
return false;

// get add node
auto tconv_output = loco::succs(tconv);
assert(tconv_output.size() == 1);
auto add = dynamic_cast<luci::CircleAdd *>(*tconv_output.begin());
if (not add)
return false;
if (add->dtype() != loco::DataType::FLOAT32)
return false;
if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
add->fusedActivationFunction() != luci::FusedActFunc::RELU)
return false;

// get addition
// Allow Add node only with FLOAT32 data type.
RETURN_FALSE_UNLESS(add->dtype() == loco::DataType::FLOAT32);
// Allow Add node only with specific activations.
RETURN_FALSE_UNLESS(add->fusedActivationFunction() == luci::FusedActFunc::NONE ||
add->fusedActivationFunction() == luci::FusedActFunc::RELU6 ||
add->fusedActivationFunction() == luci::FusedActFunc::RELU);
// Find the pattern of Add(TransposeConv, CircleConst):
luci::CircleTransposeConv *tconv = nullptr;
luci::CircleConst *addition = nullptr;
if (add->x() == tconv)
addition = dynamic_cast<luci::CircleConst *>(add->y());
else
addition = dynamic_cast<luci::CircleConst *>(add->x());
RETURN_FALSE_UNLESS(luci::fill(&tconv, &addition).with_commutative_args_of(add));

if (not addition)
return false;
RETURN_FALSE_UNLESS(loco::succs(tconv).size() == 1);

// Skip if tconv has fused activation.
RETURN_FALSE_UNLESS(tconv->fusedActivationFunction() == luci::FusedActFunc::NONE);
// Check whether tconv has bias or not. This optimization works only if it doesn't.
auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
RETURN_FALSE_UNLESS(bias);
// Get weights of tconv:
auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
RETURN_FALSE_UNLESS(filter);
RETURN_FALSE_UNLESS(filter->dtype() == loco::DataType::FLOAT32);

// addition dim(0) == tconv filter channel dim
if (addition->rank() != 1)
return false;
RETURN_FALSE_UNLESS(addition->rank() == 1);

auto addition_dim = addition->dim(0).value();
auto filter_channel_dim = filter->dim(0).value();
if (filter_channel_dim != addition_dim)
return false;
RETURN_FALSE_UNLESS(filter_channel_dim == addition_dim);

// fuse addition with transposed conv
// Fuse addition with transposed conv:
tconv->bias(addition);

if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
{
auto name = addition->name();
assert(name.length() > 0);
// separate relu op from add op
// Separate relu op from add op:
auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
relu->features(tconv);
relu->name(name + "/Relu6");
luci::add_origin(relu, luci::get_origin(add));

// remove add node
// Remove add node.
replace(add).with(relu);
}
else if (add->fusedActivationFunction() == luci::FusedActFunc::RELU)
{
auto name = addition->name();
assert(name.length() > 0);
// separate relu op from add op
// Separate relu op from add op:
auto relu = add->graph()->nodes()->create<luci::CircleRelu>();
relu->features(tconv);
relu->name(name + "/Relu");
luci::add_origin(relu, luci::get_origin(add));

// remove add node
// Remove add node.
replace(add).with(relu);
}
else
{
// Remove add node.
replace(add).with(tconv);
}

// set origin
// Set new origin.
luci::add_origin(tconv, luci::get_origin(add));

return true;
@@ -140,12 +132,9 @@ bool FuseAddWithTConvPass::run(loco::Graph *g)
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
if (not tconv)
continue;

if (fuse_add_with_tconv(tconv))
changed = true;
if (auto add = dynamic_cast<luci::CircleAdd *>(node))
if (fuse_add_with_tconv(add))
changed = true;
}

return changed;
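The refactor above replaces the hand-written operand matching with the luci::fill(&tconv, &addition).with_commutative_args_of(add) helper from helpers/NodeFiller.h. As a rough illustration of what such a commutative-argument matcher does — a standalone sketch, not the actual NodeFiller implementation and not part of this PR; the "_sketch" names are hypothetical — it tries to bind a binary node's two operands to the requested node types in either order:

// Sketch only: approximates the behavior of fill()/with_commutative_args_of()
// as used in the pass above. The real helpers/NodeFiller.h may differ.
#include <luci/IR/CircleNodes.h>

template <class ARG_A, class ARG_B> class NodeFillerSketch
{
public:
  NodeFillerSketch(ARG_A **a, ARG_B **b) : _a(a), _b(b) {}

  // Binds the x()/y() operands of a binary node (e.g. CircleAdd) to the
  // requested types in either order; returns false if neither order matches.
  template <class BINARY> bool with_commutative_args_of(BINARY *node)
  {
    if (try_bind(node->x(), node->y()))
      return true;
    return try_bind(node->y(), node->x());
  }

private:
  bool try_bind(loco::Node *first, loco::Node *second)
  {
    auto a = dynamic_cast<ARG_A *>(first);
    auto b = dynamic_cast<ARG_B *>(second);
    if (a == nullptr || b == nullptr)
      return false;
    *_a = a;
    *_b = b;
    return true;
  }

private:
  ARG_A **_a;
  ARG_B **_b;
};

template <class ARG_A, class ARG_B>
NodeFillerSketch<ARG_A, ARG_B> fill_sketch(ARG_A **a, ARG_B **b)
{
  return NodeFillerSketch<ARG_A, ARG_B>(a, b);
}

// Usage mirroring the pass above:
//   luci::CircleTransposeConv *tconv = nullptr;
//   luci::CircleConst *addition = nullptr;
//   if (not fill_sketch(&tconv, &addition).with_commutative_args_of(add))
//     return false;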
189 changes: 188 additions & 1 deletion compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
@@ -16,9 +16,196 @@

#include "luci/Pass/FuseAddWithTConvPass.h"

#include "helpers/CreateCircleConst.h"

#include <luci/IR/CircleNodes.h>
#include <luci/test/TestIOGraph.h>

#include <gtest/gtest.h>

TEST(FuseAddWithTConvPassTest, name)
#define ADD_VAL 5.0f
namespace
{

using namespace luci::test;

/**
* Graph for this test
*
* BEFORE (without extra_successor)
*
* |
* [CircleConst] [CircleTransposeConv]
* \ |
* [CircleAdd w/ Relu]
* |
*
* BEFORE (with extra_successor)
*
* |
* [CircleConst] [CircleTransposeConv]
* \ | |
* [CircleAdd w/ Relu] [extra FC]
* | |
*
* AFTER (if pass was successful)
*
* |
* [CircleConst as bias] |
* \ |
* [CircleTransposeConv]
* |
* ([CircleRelu/Relu])
* |
*
*/
class TConvAddGraphlet
{
public:
void init(loco::Graph *g, luci::FusedActFunc tconv_activation, bool use_bias,
bool extra_successor)
{
_tconv = g->nodes()->create<luci::CircleTransposeConv>();

std::vector<float> input_sizes_val = {1, 4, 4, 1};
_tconv_i = luci::create_const_node(g, loco::DataType::FLOAT32, {4}, input_sizes_val);
_tconv->inputSizes(_tconv_i);

std::vector<float> filter_val(18);
for (uint32_t i = 0; i < 18; i++)
filter_val.at(i) = i;

_tconv_f = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 3, 3, 2}, filter_val);
_tconv->filter(_tconv_f);

if (use_bias)
{
std::vector<float> bias_val(1, 3.0f);
_tconv_b = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, bias_val);
}
else
{
// Create CircleOutputExclude -- no bias
_tconv_b = g->nodes()->create<luci::CircleOutputExclude>();
}
_tconv->bias(_tconv_b);

_tconv->padding(luci::Padding::VALID);
auto _stride = _tconv->stride();
_stride->w(1);
_stride->h(1);
_tconv->fusedActivationFunction(tconv_activation);
_tconv->dtype(loco::DataType::FLOAT32);
_tconv->shape({1, 4, 4, 1});
_tconv->name("tconv");

if (extra_successor)
{
_extra_succ = g->nodes()->create<luci::CircleFullyConnected>();
// Set previous TConv as input to bump number of successors for it:
_extra_succ->input(_tconv);
std::vector<float> weights_val(8);
_extra_f = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 8}, weights_val);
_extra_succ->weights(_extra_f);
_extra_succ->bias(nullptr);
_extra_succ->fusedActivationFunction(luci::FusedActFunc::NONE);
_extra_succ->dtype(loco::DataType::FLOAT32);
_extra_succ->shape({1, 4, 4, 1});
_extra_succ->name("extra_fc");
}

std::vector<float> add_values(1, ADD_VAL);
_add_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, add_values);
_add_c->name("const_c");

_add = g->nodes()->create<luci::CircleAdd>();
_add->x(_tconv);
_add->y(_add_c);
_add->fusedActivationFunction(luci::FusedActFunc::RELU);
_add->dtype(loco::DataType::FLOAT32);
_add->shape({1, 4, 4, 1});

_add->name("add");
}

protected:
luci::CircleTransposeConv *_tconv = nullptr;
luci::CircleConst *_tconv_i = nullptr;
luci::CircleConst *_tconv_f = nullptr;
luci::CircleNode *_tconv_b = nullptr;
luci::CircleAdd *_add = nullptr;
luci::CircleConst *_add_c = nullptr;
luci::CircleFullyConnected *_extra_succ = nullptr;
luci::CircleConst *_extra_f = nullptr;
};

class FuseAddWithTConvTestGraph : public TestIOGraph, public TConvAddGraphlet
{
public:
void init(luci::FusedActFunc tconv_activation, bool use_bias, bool extra_successor)
{
TestIOGraph::init({1, 2, 2, 2}, {1, 4, 4, 1});
TConvAddGraphlet::init(g(), tconv_activation, use_bias, extra_successor);

_tconv->outBackprop(input());

output()->from(_add);
}
};

class FuseAddWithTConvPassTest : public ::testing::Test
{
public:
FuseAddWithTConvTestGraph g;
luci::FuseAddWithTConvPass pass;
};

} // namespace

TEST_F(FuseAddWithTConvPassTest, tconv_add_fuse)
{
g.init(luci::FusedActFunc::NONE, false /* use_bias */, false /* extra_successor */);

EXPECT_EQ(true, pass.run(g.g()));

auto relu = dynamic_cast<luci::CircleRelu *>(g.output()->from());
EXPECT_NE(nullptr, relu);
EXPECT_STREQ(relu->name().c_str(), "const_c/Relu");

auto tconv = dynamic_cast<luci::CircleTransposeConv *>(relu->features());
EXPECT_NE(nullptr, tconv);

auto bias = loco::must_cast<luci::CircleConst *>(tconv->bias());
EXPECT_NE(nullptr, bias);

for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++)
{
EXPECT_EQ(ADD_VAL, bias->at<loco::DataType::FLOAT32>(i));
}
}

TEST_F(FuseAddWithTConvPassTest, tconv_with_bias_NEG)
{
g.init(luci::FusedActFunc::NONE, true /* use_bias */, false /* extra_successor */);

EXPECT_EQ(false, pass.run(g.g()));
}

TEST_F(FuseAddWithTConvPassTest, tconv_with_activation_NEG)
{
g.init(luci::FusedActFunc::RELU, false /* use_bias */, false /* extra_successor */);

EXPECT_EQ(false, pass.run(g.g()));
}

TEST_F(FuseAddWithTConvPassTest, tconv_with_extra_successor_NEG)
{
g.init(luci::FusedActFunc::NONE, false /* use_bias */, true /* extra_successor */);

EXPECT_EQ(false, pass.run(g.g()));
}

TEST_F(FuseAddWithTConvPassTest, name)
{
luci::FuseAddWithTConvPass pass;
auto const name = pass.name();
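The new tests rely on luci::create_const_node from helpers/CreateCircleConst.h to build the filter, bias, and added constants. As a rough, hypothetical sketch of what such a helper does for FLOAT32 data — assuming the usual luci CircleConst API, not the actual helper in this repository — it creates a CircleConst in the graph, sets its dtype and shape, and copies the values in:

// Sketch only: a FLOAT32-specific stand-in for helpers/CreateCircleConst.h.
#include <luci/IR/CircleNodes.h>

#include <loco.h>

#include <cstdint>
#include <vector>

luci::CircleConst *create_f32_const_sketch(loco::Graph *g, const std::vector<uint32_t> &shape,
                                           const std::vector<float> &values)
{
  auto node = g->nodes()->create<luci::CircleConst>();
  node->dtype(loco::DataType::FLOAT32);

  // Set the shape and count the number of elements.
  node->rank(shape.size());
  uint32_t num_elems = 1;
  for (uint32_t i = 0; i < shape.size(); ++i)
  {
    node->dim(i).set(shape.at(i));
    num_elems *= shape.at(i);
  }
  node->shape_status(luci::ShapeStatus::VALID);

  // Allocate the constant buffer and copy the values in.
  node->size<loco::DataType::FLOAT32>(num_elems);
  for (uint32_t i = 0; i < num_elems; ++i)
    node->at<loco::DataType::FLOAT32>(i) = values.at(i);

  return node;
}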