From fdd1efed236472f73c001ed25a317fafb81a8a5d Mon Sep 17 00:00:00 2001
From: Jan Iwaszkiewicz
Date: Thu, 22 Aug 2024 14:39:37 +0200
Subject: [PATCH 1/3] [luci/pass] Refactor FuseAddWithTConvPass

This commit changes the order of searching for the pattern and adds tests
for the pass.

ONE-DCO-1.0-Signed-off-by: Jan Iwaszkiewicz
---
 .../luci/pass/src/FuseAddWithTConvPass.cpp    |  84 ++++----
 .../pass/src/FuseAddWithTConvPass.test.cpp    | 192 +++++++++++++++++-
 2 files changed, 226 insertions(+), 50 deletions(-)

diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
index d8e9f11f585..7032e6927b3 100644
--- a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
+++ b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2020-2024 Samsung Electronics Co., Ltd. All Rights Reserved
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,11 +16,18 @@
 
 #include "luci/Pass/FuseAddWithTConvPass.h"
 
+#include "helpers/NodeFiller.h"
+
 #include <luci/IR/CircleNodes.h>
 #include <luci/Profile/CircleNodeOrigin.h>
 
 namespace
 {
+
+#define RETURN_FALSE_UNLESS(cond) \
+  if (not(cond))                  \
+    return false;
+
 /**
  * Fuse Add to TransposeConv if possible
  *
  *
  * Note: CircleRelu/Relu6 is inserted if Add activation is ReLU6
  */
-bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
+bool fuse_add_with_tconv(luci::CircleAdd *add)
 {
-  // skip if tconv has fused activation
-  if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
-    return false;
-  // check whether it has bias or not. This optimization works only if it doesn't.
-  auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
-  if (not bias)
-    return false;
-
-  // get weight of tconv
-  auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
-  if (not filter)
-    return false;
-  if (filter->dtype() != loco::DataType::FLOAT32)
-    return false;
-
-  // get add node
-  auto tconv_output = loco::succs(tconv);
-  assert(tconv_output.size() == 1);
-  auto add = dynamic_cast<luci::CircleAdd *>(*tconv_output.begin());
-  if (not add)
-    return false;
-  if (add->dtype() != loco::DataType::FLOAT32)
-    return false;
-  if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
-      add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
-      add->fusedActivationFunction() != luci::FusedActFunc::RELU)
-    return false;
-
-  // get addition
+  // Allow Add node only with FLOAT32 data type:
+  RETURN_FALSE_UNLESS(add->dtype() == loco::DataType::FLOAT32);
+
+  RETURN_FALSE_UNLESS(add->fusedActivationFunction() == luci::FusedActFunc::NONE ||
+                      add->fusedActivationFunction() == luci::FusedActFunc::RELU6 ||
+                      add->fusedActivationFunction() == luci::FusedActFunc::RELU);
+  // Find the pattern of Add(TransposeConv, CircleConst):
+  luci::CircleTransposeConv *tconv = nullptr;
   luci::CircleConst *addition = nullptr;
-  if (add->x() == tconv)
-    addition = dynamic_cast<luci::CircleConst *>(add->y());
-  else
-    addition = dynamic_cast<luci::CircleConst *>(add->x());
+  RETURN_FALSE_UNLESS(luci::fill(&tconv, &addition).with_commutative_args_of(add));
 
-  if (not addition)
-    return false;
+  RETURN_FALSE_UNLESS(loco::succs(tconv).size() == 1);
+
+  // Skip if tconv has fused activation.
+  RETURN_FALSE_UNLESS(tconv->fusedActivationFunction() == luci::FusedActFunc::NONE);
+  // Check whether tconv has bias or not. This optimization works only if it doesn't.
+  auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
+  RETURN_FALSE_UNLESS(bias);
+  // Get weights of tconv:
+  auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
+  RETURN_FALSE_UNLESS(filter);
+  RETURN_FALSE_UNLESS(filter->dtype() == loco::DataType::FLOAT32);
 
   // addition dim(0) == tconv filter channel dim
-  if (addition->rank() != 1)
-    return false;
+  RETURN_FALSE_UNLESS(addition->rank() == 1);
+
   auto addition_dim = addition->dim(0).value();
   auto filter_channel_dim = filter->dim(0).value();
-  if (filter_channel_dim != addition_dim)
-    return false;
+  RETURN_FALSE_UNLESS(filter_channel_dim == addition_dim);
 
-  // fuse addition with transposed conv
+  // Fuse addition with transposed conv:
   tconv->bias(addition);
 
   if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
@@ -140,12 +131,9 @@ bool FuseAddWithTConvPass::run(loco::Graph *g)
   bool changed = false;
   for (auto node : loco::active_nodes(loco::output_nodes(g)))
   {
-    auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
-    if (not tconv)
-      continue;
-
-    if (fuse_add_with_tconv(tconv))
-      changed = true;
+    if (auto add = dynamic_cast<luci::CircleAdd *>(node))
+      if (fuse_add_with_tconv(add))
+        changed = true;
   }
 
   return changed;
diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
index 8748d73efcd..8f0ef456fa5 100644
--- a/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
+++ b/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021-2024 Samsung Electronics Co., Ltd. All Rights Reserved
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,9 +16,197 @@
 
 #include "luci/Pass/FuseAddWithTConvPass.h"
 
+#include "helpers/CreateCircleConst.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/test/TestIOGraph.h>
+
 #include <gtest/gtest.h>
 
-TEST(FuseAddWithTConvPassTest, name)
+#define ADD_VAL 5.0f
+
+namespace
+{
+
+using namespace luci::test;
+
+/**
+ * Graph for this test
+ *
+ * BEFORE (without extra_successor)
+ *
+ *                            |
+ *   [CircleConst]   [CircleTransposeConv]
+ *              \             |
+ *            [CircleAdd w/ Relu]
+ *                     |
+ *
+ * BEFORE (with extra_successor)
+ *
+ *                            |
+ *   [CircleConst]   [CircleTransposeConv]
+ *              \             |          |
+ *            [CircleAdd w/ Relu]   [extra FC]
+ *                     |                 |
+ *
+ * AFTER (if pass was successful)
+ *
+ *                            |
+ *   [CircleConst as bias]    |
+ *                  \         |
+ *               [CircleTransposeConv]
+ *                        |
+ *                ([CircleRelu/Relu])
+ *                        |
+ *
+ */
+class TConvAddGraphlet
+{
+public:
+  // TODO: FIX all of this and testcases
+  void init(loco::Graph *g, luci::FusedActFunc tconv_activation, bool use_bias,
+            bool extra_successor)
+  {
+    _tconv = g->nodes()->create<luci::CircleTransposeConv>();
+
+    std::vector<float> input_sizes_val = {1, 4, 4, 1};
+    _tconv_i = luci::create_const_node(g, loco::DataType::FLOAT32, {4}, input_sizes_val);
+    _tconv->inputSizes(_tconv_i);
+
+    std::vector<float> filter_val(18);
+    for (uint32_t i = 0; i < 18; i++)
+      filter_val.at(i) = i;
+
+    _tconv_f = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 3, 3, 2}, filter_val);
+    _tconv->filter(_tconv_f);
+
+    if (use_bias)
+    {
+      std::vector<float> bias_val(1, 3.0f);
+      _tconv_b = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, bias_val);
+    }
+    else
+    {
+      // Create CircleOutputExclude -- no bias
+      _tconv_b = g->nodes()->create<luci::CircleOutputExclude>();
+    }
+    _tconv->bias(_tconv_b);
+
+    _tconv->padding(luci::Padding::VALID);
+    auto _stride = _tconv->stride();
+    _stride->w(1);
+    _stride->h(1);
+    _tconv->fusedActivationFunction(tconv_activation);
+    _tconv->dtype(loco::DataType::FLOAT32);
+    _tconv->shape({1, 4, 4, 1});
+    _tconv->name("tconv");
+
+    if (extra_successor)
+    {
+      _extra_succ = g->nodes()->create<luci::CircleFullyConnected>();
+      // Set previous TConv as input to bump number of successors for it:
+      _extra_succ->input(_tconv);
+      std::vector<float> weights_val(8);
+      _extra_f = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 8}, weights_val);
+      _extra_succ->weights(_extra_f);
+      _extra_succ->bias(nullptr);
+      _extra_succ->fusedActivationFunction(luci::FusedActFunc::NONE);
+      _extra_succ->dtype(loco::DataType::FLOAT32);
+      _extra_succ->shape({1, 4, 4, 1});
+      _extra_succ->name("extra_fc");
+    }
+
+    std::vector<float> add_values(1, ADD_VAL);
+    _add_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, add_values);
+    _add_c->name("const_c");
+
+    _add = g->nodes()->create<luci::CircleAdd>();
+    _add->x(_tconv);
+    _add->y(_add_c);
+    _add->fusedActivationFunction(luci::FusedActFunc::RELU);
+    _add->dtype(loco::DataType::FLOAT32);
+    _add->shape({1, 4, 4, 1});
+
+    _add->name("add");
+  }
+
+protected:
+  luci::CircleTransposeConv *_tconv = nullptr;
+  luci::CircleConst *_tconv_i = nullptr;
+  luci::CircleConst *_tconv_f = nullptr;
+  luci::CircleNode *_tconv_b = nullptr;
+  luci::CircleAdd *_add = nullptr;
+  luci::CircleConst *_add_c = nullptr;
+  luci::CircleFullyConnected *_extra_succ = nullptr;
+  luci::CircleConst *_extra_f = nullptr;
+};
+
+class FuseAddWithTConvTestGraph : public TestIOGraph, public TConvAddGraphlet
+{
+public:
+  void init(luci::FusedActFunc tconv_activation, bool use_bias, bool extra_successor)
+  {
+    TestIOGraph::init({1, 2, 2, 2}, {1, 4, 4, 1});
+    TConvAddGraphlet::init(g(), tconv_activation, use_bias, extra_successor);
+
+    _tconv->outBackprop(input());
+
+    output()->from(_add);
+  }
+};
+
+class FuseAddWithTConvPassTest : public ::testing::Test
+{
+public:
+  FuseAddWithTConvTestGraph g;
+  luci::FuseAddWithTConvPass pass;
+};
+
+} // namespace
+
+TEST_F(FuseAddWithTConvPassTest, tconv_add_fuse)
+{
+  g.init(luci::FusedActFunc::NONE, false /* use_bias */, false /* extra_successor */);
+
+  EXPECT_EQ(true, pass.run(g.g()));
+
+  auto relu = dynamic_cast<luci::CircleRelu *>(g.output()->from());
+  EXPECT_NE(nullptr, relu);
+  EXPECT_STREQ(relu->name().c_str(), "const_c/Relu");
+
+  auto tconv = dynamic_cast<luci::CircleTransposeConv *>(relu->features());
+  EXPECT_NE(nullptr, tconv);
+
+  auto bias = loco::must_cast<luci::CircleConst *>(tconv->bias());
+  EXPECT_NE(nullptr, bias);
+
+  for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++)
+  {
+    EXPECT_EQ(ADD_VAL, bias->at<loco::DataType::FLOAT32>(i));
+  }
+}
+
+TEST_F(FuseAddWithTConvPassTest, tconv_with_bias_NEG)
+{
+  g.init(luci::FusedActFunc::NONE, true /* use_bias */, false /* extra_successor */);
+
+  EXPECT_EQ(false, pass.run(g.g()));
+}
+
+TEST_F(FuseAddWithTConvPassTest, tconv_with_activation_NEG)
+{
+  g.init(luci::FusedActFunc::RELU, false /* use_bias */, false /* extra_successor */);
+
+  EXPECT_EQ(false, pass.run(g.g()));
+}
+
+TEST_F(FuseAddWithTConvPassTest, tconv_with_extra_successor_NEG)
+{
+  g.init(luci::FusedActFunc::NONE, false /* use_bias */, true /* extra_successor */);
+
+  EXPECT_EQ(false, pass.run(g.g()));
+}
+
+TEST_F(FuseAddWithTConvPassTest, name)
 {
   luci::FuseAddWithTConvPass pass;
   auto const name = pass.name();

From 6b5d651d699139551318be98bc9bd05e7d8c23fd Mon Sep 17 00:00:00 2001
From: Jan Iwaszkiewicz
Date: Thu, 22 Aug 2024 16:45:11 +0200
Subject: [PATCH 2/3] Remove comment

---
 compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
index 8f0ef456fa5..eede887ad95 100644
--- a/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
+++ b/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
@@ -62,7 +62,6 @@ using namespace luci::test;
 class TConvAddGraphlet
 {
 public:
-  // TODO: FIX all of this and testcases
   void init(loco::Graph *g, luci::FusedActFunc tconv_activation, bool use_bias,
             bool extra_successor)
   {

From 3d8648eecf879f32627be2b254a8b2bb456fc1c4 Mon Sep 17 00:00:00 2001
From: Jan Iwaszkiewicz
Date: Wed, 28 Aug 2024 11:43:28 +0200
Subject: [PATCH 3/3] Refactor comments and copyright headers

---
 compiler/luci/pass/src/FuseAddWithTConvPass.cpp      | 17 +++++++++--------
 .../luci/pass/src/FuseAddWithTConvPass.test.cpp      |  2 +-
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
index 7032e6927b3..3563aa8fe52 100644
--- a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
+++ b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -51,9 +51,9 @@ namespace
  */
 bool fuse_add_with_tconv(luci::CircleAdd *add)
 {
-  // Allow Add node only with FLOAT32 data type:
+  // Allow Add node only with FLOAT32 data type.
   RETURN_FALSE_UNLESS(add->dtype() == loco::DataType::FLOAT32);
-
+  // Allow Add node only with specific activations.
   RETURN_FALSE_UNLESS(add->fusedActivationFunction() == luci::FusedActFunc::NONE ||
                       add->fusedActivationFunction() == luci::FusedActFunc::RELU6 ||
                       add->fusedActivationFunction() == luci::FusedActFunc::RELU);
@@ -88,34 +88,35 @@ bool fuse_add_with_tconv(luci::CircleAdd *add)
   {
     auto name = addition->name();
     assert(name.length() > 0);
 
-    // separate relu op from add op
+    // Separate relu op from add op:
     auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
     relu->features(tconv);
     relu->name(name + "/Relu6");
     luci::add_origin(relu, luci::get_origin(add));
 
-    // remove add node
+    // Remove add node.
     replace(add).with(relu);
   }
   else if (add->fusedActivationFunction() == luci::FusedActFunc::RELU)
   {
     auto name = addition->name();
     assert(name.length() > 0);
 
-    // separate relu op from add op
+    // Separate relu op from add op:
    auto relu = add->graph()->nodes()->create<luci::CircleRelu>();
     relu->features(tconv);
     relu->name(name + "/Relu");
     luci::add_origin(relu, luci::get_origin(add));
 
-    // remove add node
+    // Remove add node.
     replace(add).with(relu);
   }
   else
   {
+    // Remove add node.
     replace(add).with(tconv);
   }
 
-  // set origin
+  // Set new origin.
   luci::add_origin(tconv, luci::get_origin(add));
 
   return true;
diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
index eede887ad95..3bf7d112aa2 100644
--- a/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
+++ b/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
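
Note on the matching helper the refactor relies on: luci::fill(&tconv, &addition).with_commutative_args_of(add) (from "helpers/NodeFiller.h") binds the two operands of a commutative node to typed pointers regardless of operand order, which is what lets the rewritten matcher drop the old add->x() == tconv branching. The standalone sketch below illustrates that idea with plain structs; the Node/Add types and fill_commutative here are illustrative stand-ins, not the actual luci classes or the real NodeFiller implementation.

#include <cassert>

// Minimal stand-ins for graph nodes (illustrative only).
struct Node
{
  virtual ~Node() = default;
};
struct TConv : Node
{
};
struct Const : Node
{
};
struct Add : Node
{
  Node *x = nullptr;
  Node *y = nullptr;
};

// Try to bind (x, y) or (y, x) to the requested types -- the same trick the
// NodeFiller helper uses so callers don't care which operand holds the constant.
template <class A, class B> bool fill_commutative(Add *add, A **a, B **b)
{
  auto *ax = dynamic_cast<A *>(add->x);
  auto *by = dynamic_cast<B *>(add->y);
  if (ax && by)
  {
    *a = ax;
    *b = by;
    return true;
  }
  auto *ay = dynamic_cast<A *>(add->y);
  auto *bx = dynamic_cast<B *>(add->x);
  if (ay && bx)
  {
    *a = ay;
    *b = bx;
    return true;
  }
  return false; // pattern Add(TConv, Const) not found
}

int main()
{
  TConv tconv;
  Const addition;
  Add add;
  add.x = &addition; // operands deliberately in swapped order
  add.y = &tconv;

  TConv *t = nullptr;
  Const *c = nullptr;
  assert(fill_commutative(&add, &t, &c)); // still matches
  assert(t == &tconv && c == &addition);
  return 0;
}

Starting the search from the Add node and filling both operands in one call is also what makes the "single successor" and "no bias" checks read as a flat list of guards instead of nested lookups.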
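
For readers unfamiliar with the RETURN_FALSE_UNLESS style used throughout the rewritten matcher: it is only an early-exit guard that returns false as soon as one requirement on the pattern fails. A minimal standalone illustration follows; looks_like_pattern and its three checks are made up for the example and do not mirror the exact conditions in the pass.

#include <iostream>

// Early-exit guard in the same style as the pass (illustrative copy).
#define RETURN_FALSE_UNLESS(cond) \
  if (not(cond))                  \
    return false;

// A toy predicate chaining several guards, mirroring how the matcher
// bails out at the first requirement that does not hold.
bool looks_like_pattern(int rank, bool has_bias, bool fused_act_none)
{
  RETURN_FALSE_UNLESS(rank == 1);
  RETURN_FALSE_UNLESS(not has_bias);
  RETURN_FALSE_UNLESS(fused_act_none);
  return true;
}

int main()
{
  std::cout << looks_like_pattern(1, false, true) << "\n"; // 1: all guards pass
  std::cout << looks_like_pattern(1, true, true) << "\n";  // 0: bias present, early exit
  return 0;
}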