[one-optimize] Fuse Mul with FullyConnected layer #13528

Closed · wants to merge 48 commits
Changes from 4 commits
Commits (48)
28d8089
[one-optimize] Fuse Mul with FullyConnected layer
jiwaszki Jul 26, 2024
87774c5
Move mul_with_fully_connected pass after the mul_with_div
jiwaszki Jul 29, 2024
4c7f5a9
Remove weights constant check
jiwaszki Jul 29, 2024
cfcb68b
Change order of updating the nodes, more consuming one is now later
jiwaszki Jul 29, 2024
e8b06b5
Fix values updating and add luci tests
jiwaszki Aug 1, 2024
3bf2649
Fix codestyle
jiwaszki Aug 2, 2024
d386541
Rename pass
jiwaszki Aug 5, 2024
4180734
Add luci tests with models
jiwaszki Aug 5, 2024
761d303
Fix scalar vs multi-dim case
jiwaszki Aug 5, 2024
fa4733b
Separate bias and weights updating, remove checks
jiwaszki Aug 6, 2024
40dcacf
[luci/pass] Introduce FuseMulWithFullyConnectedPass
jiwaszki Aug 6, 2024
b717234
[one-cmds] Add an option for FuseMulWithFullyConnectedPass
jiwaszki Aug 6, 2024
a568d25
[circle2circle] Dredd test for FuseMulWithFullyConnectedPass
jiwaszki Aug 7, 2024
d3246e3
[luci/pass] Value test for FuseMulWithFullyConnectedPass
jiwaszki Aug 7, 2024
f661561
Change constness of args, move tests and move FuseMulWithFC after Fus…
jiwaszki Aug 7, 2024
85d9783
Fix codestyle
jiwaszki Aug 7, 2024
51dd43c
Fix order of cmds
jiwaszki Aug 7, 2024
e3b354e
Remove default arguments
jiwaszki Aug 8, 2024
ffc36e9
Remove default args
jiwaszki Aug 8, 2024
835126a
Merge remote-tracking branch 'upstream/master' into jiwaszki/fuse_mul_fc
jiwaszki Aug 9, 2024
31e25ed
Refactor solution and apply comments
jiwaszki Aug 9, 2024
396d733
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 9, 2024
d5ec1d8
Merge branch 'jiwaszki/fuse_mul_fc_one_cmds' into jiwaszki/fuse_mul_fc
jiwaszki Aug 9, 2024
8b17f47
Add handling of no bias case to pass
jiwaszki Aug 9, 2024
8d90e50
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 9, 2024
715cdf7
Remove random newline
jiwaszki Aug 9, 2024
9e22b26
Apply comments, refactor tests and add proper handling of OUTPUTEXCLUDE
jiwaszki Aug 12, 2024
62a09a0
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 12, 2024
1b6c71f
Resolve one-cmds duplication
jiwaszki Aug 12, 2024
8977ef9
Handle rank 0 and 1
jiwaszki Aug 12, 2024
dbed1b9
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 12, 2024
53aa943
Add new testcase
jiwaszki Aug 12, 2024
0c2bb71
Add new testcase
jiwaszki Aug 12, 2024
e3ff517
[res/tfl_recipes] Add new Net_FullyConnected_Mul
jiwaszki Aug 12, 2024
d6e8b4a
Merge branch 'jiwaszki/fuse_mul_fc_new_tfl_recipes' into jiwaszki/fus…
jiwaszki Aug 12, 2024
b4ebd44
Merge branch 'jiwaszki/fuse_mul_fc_c2c_dredd' into jiwaszki/fuse_mul_fc
jiwaszki Aug 12, 2024
cddd353
Merge branch 'jiwaszki/fuse_mul_fc_luci_test' into jiwaszki/fuse_mul_fc
jiwaszki Aug 12, 2024
678869c
Change name of operand from B to scale
jiwaszki Aug 13, 2024
af3119d
Merge branch 'jiwaszki/fuse_mul_fc_new_tfl_recipes' into jiwaszki/fus…
jiwaszki Aug 13, 2024
b085181
Update names from scalar to single element
jiwaszki Aug 13, 2024
1aa79cc
Update tests
jiwaszki Aug 13, 2024
f02cb88
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 13, 2024
1bb278d
Fix codestyle
jiwaszki Aug 13, 2024
27dec03
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 13, 2024
79a2213
Search from mul, update tests
jiwaszki Aug 14, 2024
7ea759a
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 14, 2024
550e798
Annotate requirement of one successor and refactor checks
jiwaszki Aug 19, 2024
bda96d8
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 19, 2024
8 changes: 4 additions & 4 deletions compiler/circle2circle/src/Circle2Circle.cpp
@@ -112,14 +112,14 @@ int entry(int argc, char **argv)
add_switch(arser, "--fuse_mean_with_mean",
"This will fuse two Mean operations when they follow one by one. This will fold them "
"into one operation and merge reduction indices.");
add_switch(arser, "--fuse_mul_with_fully_connected",
"This will fuse Mul operator to FullyConnected operator");
add_switch(arser, "--fuse_mul_to_fullyconnected_weights",
"This will fuse Mul to following FullyConnected weights");
add_switch(arser, "--fuse_mul_with_conv",
"This will fuse Mul operation with a preceding Conv if possible.");
add_switch(arser, "--fuse_mul_with_div",
"This will fuse Mul operation with a Div operation whose numerator is const.");
add_switch(arser, "--fuse_mul_with_fully_connected",
Contributor review comment:

Suggested change:
- add_switch(arser, "--fuse_mul_with_fully_connected",
+ add_switch(arser, "--fuse_mul_with_fullyconnected",

to follow other options

Author reply:

Got it! Done in scope of d386541

"This will fuse Mul operator with a preceding FullyConnected operator.");
add_switch(arser, "--fuse_slice_with_tconv",
"This will fuse Slice operation with a preceding TConv if possible.");
add_switch(arser, "--fuse_transpose_with_mean",
@@ -314,8 +314,6 @@ int entry(int argc, char **argv)
options->enable(Algorithms::FuseBatchNormWithDwConv);
if (arser.get<bool>("--fuse_batchnorm_with_tconv"))
options->enable(Algorithms::FuseBatchNormWithTConv);
if (arser.get<bool>("--fuse_mul_with_fully_connected"))
options->enable(Algorithms::FuseMulWithFullyConnected);
if (arser.get<bool>("--fuse_mul_to_fullyconnected_weights"))
options->enable(Algorithms::FuseMulToFullyConnectedWeights);
if (arser.get<bool>("--fuse_slice_with_tconv"))
@@ -330,6 +328,8 @@ int entry(int argc, char **argv)
options->enable(Algorithms::FuseMulWithConv);
if (arser.get<bool>("--fuse_mul_with_div"))
options->enable(Algorithms::FuseMulWithDiv);
if (arser.get<bool>("--fuse_mul_with_fully_connected"))
options->enable(Algorithms::FuseMulWithFullyConnected);
if (arser.get<bool>("--make_batchnorm_gamma_positive"))
options->enable(Algorithms::MakeBatchNormGammaPositive);
if (arser.get<bool>("--fuse_preactivation_batchnorm"))
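For context, a hedged usage sketch: assuming circle2circle keeps its usual "circle2circle [options] <input> <output>" form (the file names below are hypothetical), the new pass would be enabled from the command line roughly as:

  $ circle2circle --fuse_mul_with_fully_connected input.circle output.circle
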
2 changes: 1 addition & 1 deletion compiler/luci/pass/include/luci/CircleOptimizer.h
@@ -41,7 +41,6 @@ class CircleOptimizer final
FuseBatchNormWithConv,
FuseBatchNormWithDwConv,
FuseBatchNormWithTConv,
FuseMulWithFullyConnected,
FuseMulToFullyConnectedWeights,
FuseSliceWithTConv,
FuseBCQ,
@@ -50,6 +49,7 @@
FuseMeanWithMean,
FuseMulWithConv,
FuseMulWithDiv,
FuseMulWithFullyConnected,
FuseTransposeWithMean,
ResolveCustomOpAdd,
ResolveCustomOpBatchMatMul,
2 changes: 1 addition & 1 deletion compiler/luci/pass/src/CircleOptimizer.cpp
@@ -42,12 +42,12 @@
#include "luci/Pass/FuseBatchNormWithDwConvPass.h"
#include "luci/Pass/FuseBatchNormWithTConvPass.h"
#include "luci/Pass/FuseBCQPass.h"
#include "luci/Pass/FuseMulWithFullyConnectedPass.h"
#include "luci/Pass/FuseMulToFullyConnectedWeightsPass.h"
#include "luci/Pass/FuseInstanceNormPass.h"
#include "luci/Pass/FuseMeanWithMeanPass.h"
#include "luci/Pass/FuseMulWithConvPass.h"
#include "luci/Pass/FuseMulWithDivPass.h"
#include "luci/Pass/FuseMulWithFullyConnectedPass.h"
#include "luci/Pass/FusePreActivationBatchNormPass.h"
#include "luci/Pass/FusePReluPass.h"
#include "luci/Pass/FuseGeluPass.h"
60 changes: 35 additions & 25 deletions compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp
@@ -29,21 +29,38 @@ namespace
if (not(cond)) \
return false;

inline void update_values(luci::CircleConst *fused_node, luci::CircleConst *multiplication)
inline void update_values(luci::CircleConst *fused_node, luci::CircleConst *multiplication, bool is_weights)
{
auto node_size = fused_node->size<loco::DataType::FLOAT32>();
// Scalar:
auto mul_size = multiplication->size<loco::DataType::FLOAT32>();
// Scalar multiplication:
if (multiplication->rank() == 1 ||
multiplication->rank() == 0 && multiplication->size<loco::DataType::FLOAT32>() == 1)
multiplication->rank() == 0 && mul_size == 1)
{
for (uint32_t i = 0; i < node_size; i++)
fused_node->at<loco::DataType::FLOAT32>(i) *= multiplication->at<loco::DataType::FLOAT32>(0);
}
// N-size:
// N-size multiplication:
else
{
for (uint32_t i = 0; i < node_size; i++)
fused_node->at<loco::DataType::FLOAT32>(i) *= multiplication->at<loco::DataType::FLOAT32>(i);
// Go along channels; the multiplication size is ensured to match the number of channels.
if (is_weights) // weights 2-D
{
auto count = fused_node->dim(0).value();
auto size = fused_node->dim(fused_node->rank() - 1).value();
float val;
for (uint32_t c = 0; c < count; c++) {
val = multiplication->at<loco::DataType::FLOAT32>(c);
for (uint32_t i = 0; i < size; i++) {
fused_node->at<loco::DataType::FLOAT32>(c * size + i) *= val;
}
}
}
else // bias 1-D
{
for (uint32_t i = 0; i < node_size; i++)
fused_node->at<loco::DataType::FLOAT32>(i) *= multiplication->at<loco::DataType::FLOAT32>(i);
}
}
}

@@ -98,26 +115,15 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc)
// Check channel-wise broadcasting:
for (uint32_t i = 0; i < rank - 1; i++)
RETURN_FALSE_UNLESS(multiplication->dim(i).value() == 1);
// Check that the last dimension of Mul matches the first dimension of the FullyConnected weights
RETURN_FALSE_UNLESS(multiplication->dim(rank - 1) == weights->dim(0));
}
// Scalar case:
else if (multiplication->rank() == 1 || multiplication->rank() == 0)
{
RETURN_FALSE_UNLESS(multiplication->size<loco::DataType::FLOAT32>() != 0);
}

// Update weights accordingly.
RETURN_FALSE_UNLESS(weights->opcode() == luci::CircleOpcode::CIRCLECONST or
weights->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
// Create new weights to be updated with values:
auto fused_weights = luci::clone(weights);
RETURN_FALSE_UNLESS(fused_weights->size<loco::DataType::FLOAT32>() ==
weights->size<loco::DataType::FLOAT32>());

update_values(fused_weights, multiplication);

fc->weights(fused_weights);

// Update bias accordingly.
// Only supports:
// (1) constant bias
// (2) no bias
@@ -133,14 +139,18 @@
RETURN_FALSE_UNLESS(fused_bias->size<loco::DataType::FLOAT32>() ==
const_bias->size<loco::DataType::FLOAT32>());

update_values(fused_bias, multiplication);
// Create new weights to be updated with values:
auto fused_weights = luci::clone(weights);
RETURN_FALSE_UNLESS(fused_weights->size<loco::DataType::FLOAT32>() ==
weights->size<loco::DataType::FLOAT32>());
Contributor review comment:

Seems unnecessary.

// Here fused_bias's shape is either [1, 1, ..., N] or [N]
// where N is weights->dim(0).
// The shape is normalized to [N] to become the bias of FullyConnected.
fused_bias->rank(1);
fused_bias->dim(0) = weights->dim(0);
// Update bias accordingly:
update_values(fused_bias, multiplication, false);
// Update weights accordingly:
update_values(fused_weights, multiplication, true);

// Replace weights and bias:
fc->weights(fused_weights);
fc->bias(fused_bias);

// Set origin and copy Activation Function if existing:
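For intuition, the algebra behind this fusion: with FullyConnected weights W of shape [N, K] (N = output channels) and a channel-wise scale s of length N, s * (x·W^T + b) = x·(diag(s)·W)^T + s * b, so each weight row and each bias element is scaled by its channel's s[c]. Below is a minimal standalone C++ sketch of that folding, independent of the luci types used above; the function name and flat-vector layout are illustrative assumptions, not part of this PR.

// Sketch of the scale folding performed by update_values above, assuming
// row-major weights with dim(0) == output channels (N) and row length K.
#include <cstdint>
#include <vector>

void fold_channel_scale(std::vector<float> &weights, std::vector<float> &bias,
                        const std::vector<float> &scale, uint32_t N, uint32_t K)
{
  for (uint32_t c = 0; c < N; ++c)
  {
    for (uint32_t i = 0; i < K; ++i)
      weights[c * K + i] *= scale[c]; // W'[c][i] = s[c] * W[c][i]
    bias[c] *= scale[c];              // b'[c]   = s[c] * b[c]
  }
}

A scalar scale is the special case where every s[c] is the same single value, which is what the rank-0/size-1 branch of update_values handles.
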
218 changes: 218 additions & 0 deletions compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp
@@ -0,0 +1,218 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/FuseMulWithFullyConnectedPass.h"
#include "helpers/CreateCircleConst.h"

#include <luci/IR/CircleNodes.h>
#include <luci/test/TestIOGraph.h>

#include <gtest/gtest.h>

#define DIM_ONE 8
#define DIM_TWO 4
#define MUL_VAL 2.0f

namespace
{

using namespace luci::test;

/**
* Graph for this test
*
* BEFORE
*
* [FC]
* |
* [Mul w/ Relu]
*
* AFTER
*
* [FC w/ Relu] (weights and bias updated)
*
*/
class FCMulGraphlet
{
public:
FCMulGraphlet() = default;

void init(loco::Graph *g, luci::FusedActFunc fc_activation, bool is_mul_scalar)
{
std::vector<float> weights_val(DIM_ONE * DIM_TWO);
for (uint32_t i = 0; i < DIM_ONE * DIM_TWO; i++)
weights_val.at(i) = i;

_fc_f = luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE, DIM_TWO}, weights_val);

std::vector<float> bias_val(DIM_ONE);
for (uint32_t i = 0; i < DIM_ONE; i++)
bias_val.at(i) = i;

_fc_b = luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE}, bias_val);

_fc = g->nodes()->create<luci::CircleFullyConnected>();
_fc->weights(_fc_f);
_fc->bias(_fc_b);
_fc->fusedActivationFunction(fc_activation);
_fc->dtype(loco::DataType::FLOAT32);
_fc->shape({1, DIM_ONE});
_fc->name("fc");

std::vector<float> mul_values;

if(is_mul_scalar) {
mul_values.push_back(static_cast<float>(MUL_VAL));
_mul_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, mul_values);
}
else {
for (uint32_t i = 0; i < DIM_ONE; i++) {
mul_values.push_back(static_cast<float>(i));
}
_mul_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 1, 1, DIM_ONE}, mul_values);
}

_mul = g->nodes()->create<luci::CircleMul>();
_mul->x(_fc);
_mul->y(_mul_c);
_mul->fusedActivationFunction(luci::FusedActFunc::RELU);
_mul->dtype(loco::DataType::FLOAT32);
if(is_mul_scalar) {
_mul->shape({1});
}
else {
_mul->shape({1, DIM_ONE});
}
_mul->name("mul");
}

public:
luci::CircleFullyConnected *fc() { return _fc; }

void to_fm_bias(void)
{
assert(_fc != nullptr);

auto new_fc = _fc->graph()->nodes()->create<luci::CircleFullyConnected>();
_fc->bias(new_fc);
}

protected:
luci::CircleFullyConnected *_fc = nullptr;
luci::CircleMul *_mul = nullptr;
luci::CircleConst *_fc_f = nullptr;
luci::CircleConst *_fc_b = nullptr;
luci::CircleConst *_mul_c = nullptr;
};

class FuseAddWithFCTestGraph : public TestIOGraph, public FCMulGraphlet
{
public:
FuseAddWithFCTestGraph() = default;

void init(luci::FusedActFunc fc_activation = luci::FusedActFunc::NONE, bool is_mul_scalar = false)
{
TestIOGraph::init({1, DIM_TWO}, {1, DIM_ONE});
FCMulGraphlet::init(g(), fc_activation, is_mul_scalar);

_fc->input(input());

output()->from(_mul);
}
};

class FuseMulWithFullyConnectedPassTest : public ::testing::Test
{
public:
FuseAddWithFCTestGraph g;
luci::FuseMulWithFullyConnectedPass pass;
};

TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_not_scalar)
{
g.init(luci::FusedActFunc::NONE, false);

EXPECT_EQ(true, pass.run(g.g()));

auto fc = dynamic_cast<luci::CircleFullyConnected *>(g.output()->from());
EXPECT_NE(nullptr, fc);

auto weights = loco::must_cast<luci::CircleConst *>(g.fc()->weights());
auto weights_n = weights->dim(0).value();
auto weights_m = weights->dim(1).value();
uint32_t offset = 0;
for (uint32_t i = 0; i < weights_n; i++)
{
for (uint32_t j = 0; j < weights_m; j++)
{
offset = i * weights_m + j;
EXPECT_EQ(i * offset, weights->at<loco::DataType::FLOAT32>(offset));
}
}

auto bias = loco::must_cast<luci::CircleConst *>(g.fc()->bias());
for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++)
{
EXPECT_EQ(i * i, bias->at<loco::DataType::FLOAT32>(i));
}
}

TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_is_scalar)
{
g.init(luci::FusedActFunc::NONE, true);

EXPECT_EQ(true, pass.run(g.g()));

auto fc = dynamic_cast<luci::CircleFullyConnected *>(g.output()->from());
EXPECT_NE(nullptr, fc);

auto weights = loco::must_cast<luci::CircleConst *>(g.fc()->weights());
auto weights_n = weights->dim(0).value();
auto weights_m = weights->dim(1).value();
uint32_t offset = 0;
for (uint32_t i = 0; i < weights_n; i++)
{
for (uint32_t j = 0; j < weights_m; j++)
{
offset = i * weights_m + j;
EXPECT_EQ(MUL_VAL * offset, weights->at<loco::DataType::FLOAT32>(offset));
}
}

auto bias = loco::must_cast<luci::CircleConst *>(g.fc()->bias());
for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++)
{
EXPECT_EQ(MUL_VAL * i, bias->at<loco::DataType::FLOAT32>(i));
}
}

TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG)
{
g.init();

// Bias cannot be fused as it's passed as feature map.
g.to_fm_bias();

EXPECT_EQ(false, pass.run(g.g()));
}

TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_activation_NEG)
{
g.init(luci::FusedActFunc::RELU);

EXPECT_EQ(false, pass.run(g.g()));
}
}