Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[luci/pass] Refactor FuseAddWithTConvPass #13745

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 36 additions & 48 deletions compiler/luci/pass/src/FuseAddWithTConvPass.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright (c) 2020-2024 Samsung Electronics Co., Ltd. All Rights Reserved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need to do this change

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure about it as well... now I have guidance to follow. Will do!

*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,11 +16,18 @@

#include "luci/Pass/FuseAddWithTConvPass.h"

#include "helpers/NodeFiller.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>

namespace
{

#define RETURN_FALSE_UNLESS(cond) \
if (not(cond)) \
return false;

/**
* Fuse Add to TransposeConv if possible
*
Expand All @@ -42,55 +49,39 @@ namespace
*
* Note: CircleRelu/Relu6 is inserted if Add activation is ReLU6
*/
bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
bool fuse_add_with_tconv(luci::CircleAdd *add)
{
// skip if tconv has fused activation
if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
return false;
// check whether it has bias or not. This optimization works only if it doesn't.
auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
if (not bias)
return false;

// get weight of tconv
auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
if (not filter)
return false;
if (filter->dtype() != loco::DataType::FLOAT32)
return false;

// get add node
auto tconv_output = loco::succs(tconv);
assert(tconv_output.size() == 1);
auto add = dynamic_cast<luci::CircleAdd *>(*tconv_output.begin());
if (not add)
return false;
if (add->dtype() != loco::DataType::FLOAT32)
return false;
if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
add->fusedActivationFunction() != luci::FusedActFunc::RELU)
return false;

// get addition
// Allow Add node only with FLOAT32 data type:
RETURN_FALSE_UNLESS(add->dtype() == loco::DataType::FLOAT32);

RETURN_FALSE_UNLESS(add->fusedActivationFunction() == luci::FusedActFunc::NONE ||
add->fusedActivationFunction() == luci::FusedActFunc::RELU6 ||
add->fusedActivationFunction() == luci::FusedActFunc::RELU);
// Find the pattern of Add(TransposeConv, CircleConst):
luci::CircleTransposeConv *tconv = nullptr;
luci::CircleConst *addition = nullptr;
if (add->x() == tconv)
addition = dynamic_cast<luci::CircleConst *>(add->y());
else
addition = dynamic_cast<luci::CircleConst *>(add->x());
RETURN_FALSE_UNLESS(luci::fill(&tconv, &addition).with_commutative_args_of(add));

if (not addition)
return false;
RETURN_FALSE_UNLESS(loco::succs(tconv).size() == 1);

// Skip if tconv has fused activation.
RETURN_FALSE_UNLESS(tconv->fusedActivationFunction() == luci::FusedActFunc::NONE);
// Check whether tconv has bias or not. This optimization works only if it doesn't.
auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
RETURN_FALSE_UNLESS(bias);
// Get weights of tconv:
auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
RETURN_FALSE_UNLESS(filter);
RETURN_FALSE_UNLESS(filter->dtype() == loco::DataType::FLOAT32);

// addition dim(0) == tconv filter channel dim
if (addition->rank() != 1)
return false;
RETURN_FALSE_UNLESS(addition->rank() == 1);

auto addition_dim = addition->dim(0).value();
auto filter_channel_dim = filter->dim(0).value();
if (filter_channel_dim != addition_dim)
return false;
RETURN_FALSE_UNLESS(filter_channel_dim == addition_dim);

// fuse addition with transposed conv
// Fuse addition with transposed conv:
seanshpark marked this conversation as resolved.
Show resolved Hide resolved
tconv->bias(addition);

if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
Expand Down Expand Up @@ -140,12 +131,9 @@ bool FuseAddWithTConvPass::run(loco::Graph *g)
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
if (not tconv)
continue;

if (fuse_add_with_tconv(tconv))
changed = true;
if (auto add = dynamic_cast<luci::CircleAdd *>(node))
if (fuse_add_with_tconv(add))
changed = true;
}

return changed;
Expand Down
191 changes: 189 additions & 2 deletions compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright (c) 2021-2024 Samsung Electronics Co., Ltd. All Rights Reserved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,9 +16,196 @@

#include "luci/Pass/FuseAddWithTConvPass.h"

#include "helpers/CreateCircleConst.h"

#include <luci/IR/CircleNodes.h>
#include <luci/test/TestIOGraph.h>

#include <gtest/gtest.h>

TEST(FuseAddWithTConvPassTest, name)
#define ADD_VAL 5.0f
namespace
{

using namespace luci::test;

/**
* Graph for this test
*
* BEFORE (without extra_successor)
*
* |
* [CircleConst] [CircleTransposeConv]
* \ |
* [CircleAdd w/ Relu]
* |
*
* BEFORE (with extra_successor)
*
* |
* [CircleConst] [CircleTransposeConv]
* \ | |
* [CircleAdd w/ Relu] [extra FC]
* | |
*
* AFTER (if pass was successful)
*
* |
* [CircleConst as bias] |
* \ |
* [CircleTransposeConv]
* |
* ([CircleRelu/Relu])
* |
*
*/
class TConvAddGraphlet
{
public:
void init(loco::Graph *g, luci::FusedActFunc tconv_activation, bool use_bias,
bool extra_successor)
{
_tconv = g->nodes()->create<luci::CircleTransposeConv>();

std::vector<float> input_sizes_val = {1, 4, 4, 1};
_tconv_i = luci::create_const_node(g, loco::DataType::FLOAT32, {4}, input_sizes_val);
_tconv->inputSizes(_tconv_i);

std::vector<float> filter_val(18);
for (uint32_t i = 0; i < 18; i++)
filter_val.at(i) = i;

_tconv_f = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 3, 3, 2}, filter_val);
_tconv->filter(_tconv_f);

if (use_bias)
{
std::vector<float> bias_val(1, 3.0f);
_tconv_b = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, bias_val);
}
else
{
// Create CircleOutputExclude -- no bias
_tconv_b = g->nodes()->create<luci::CircleOutputExclude>();
}
_tconv->bias(_tconv_b);

_tconv->padding(luci::Padding::VALID);
auto _stride = _tconv->stride();
_stride->w(1);
_stride->h(1);
_tconv->fusedActivationFunction(tconv_activation);
_tconv->dtype(loco::DataType::FLOAT32);
_tconv->shape({1, 4, 4, 1});
_tconv->name("tconv");

if (extra_successor)
{
_extra_succ = g->nodes()->create<luci::CircleFullyConnected>();
// Set previous TConv as input to bump number of successors for it:
_extra_succ->input(_tconv);
std::vector<float> weights_val(8);
_extra_f = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 8}, weights_val);
_extra_succ->weights(_extra_f);
_extra_succ->bias(nullptr);
_extra_succ->fusedActivationFunction(luci::FusedActFunc::NONE);
_extra_succ->dtype(loco::DataType::FLOAT32);
_extra_succ->shape({1, 4, 4, 1});
_extra_succ->name("extra_fc");
}

std::vector<float> add_values(1, ADD_VAL);
_add_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, add_values);
_add_c->name("const_c");

_add = g->nodes()->create<luci::CircleAdd>();
_add->x(_tconv);
_add->y(_add_c);
_add->fusedActivationFunction(luci::FusedActFunc::RELU);
_add->dtype(loco::DataType::FLOAT32);
_add->shape({1, 4, 4, 1});

_add->name("add");
}

protected:
luci::CircleTransposeConv *_tconv = nullptr;
luci::CircleConst *_tconv_i = nullptr;
luci::CircleConst *_tconv_f = nullptr;
luci::CircleNode *_tconv_b = nullptr;
luci::CircleAdd *_add = nullptr;
luci::CircleConst *_add_c = nullptr;
luci::CircleFullyConnected *_extra_succ = nullptr;
luci::CircleConst *_extra_f = nullptr;
};

class FuseAddWithTConvTestGraph : public TestIOGraph, public TConvAddGraphlet
{
public:
void init(luci::FusedActFunc tconv_activation, bool use_bias, bool extra_successor)
{
TestIOGraph::init({1, 2, 2, 2}, {1, 4, 4, 1});
TConvAddGraphlet::init(g(), tconv_activation, use_bias, extra_successor);

_tconv->outBackprop(input());

output()->from(_add);
}
};

class FuseAddWithTConvPassTest : public ::testing::Test
{
public:
FuseAddWithTConvTestGraph g;
luci::FuseAddWithTConvPass pass;
};

} // namespace

TEST_F(FuseAddWithTConvPassTest, tconv_add_fuse)
{
g.init(luci::FusedActFunc::NONE, false /* use_bias */, false /* extra_successor */);

EXPECT_EQ(true, pass.run(g.g()));

auto relu = dynamic_cast<luci::CircleRelu *>(g.output()->from());
EXPECT_NE(nullptr, relu);
EXPECT_STREQ(relu->name().c_str(), "const_c/Relu");

auto tconv = dynamic_cast<luci::CircleTransposeConv *>(relu->features());
EXPECT_NE(nullptr, tconv);

auto bias = loco::must_cast<luci::CircleConst *>(tconv->bias());
EXPECT_NE(nullptr, bias);

for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++)
{
EXPECT_EQ(ADD_VAL, bias->at<loco::DataType::FLOAT32>(i));
}
}

TEST_F(FuseAddWithTConvPassTest, tconv_with_bias_NEG)
{
g.init(luci::FusedActFunc::NONE, true /* use_bias */, false /* extra_successor */);

EXPECT_EQ(false, pass.run(g.g()));
}

TEST_F(FuseAddWithTConvPassTest, tconv_with_activation_NEG)
{
g.init(luci::FusedActFunc::RELU, false /* use_bias */, false /* extra_successor */);

EXPECT_EQ(false, pass.run(g.g()));
}

TEST_F(FuseAddWithTConvPassTest, tconv_with_extra_successor_NEG)
{
g.init(luci::FusedActFunc::NONE, false /* use_bias */, true /* extra_successor */);

EXPECT_EQ(false, pass.run(g.g()));
}

TEST_F(FuseAddWithTConvPassTest, name)
{
luci::FuseAddWithTConvPass pass;
auto const name = pass.name();
Expand Down