-
Notifications
You must be signed in to change notification settings - Fork 158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[one-optimize] Fuse Mul with FullyConnected layer #13528
Closed
Closed
Changes from 4 commits
Commits
Show all changes
48 commits
Select commit
Hold shift + click to select a range
28d8089
[one-optimize] Fuse Mul with FullyConnected layer
jiwaszki 87774c5
Move mul_with_fully_connected pass after the mul_with_div
jiwaszki 4c7f5a9
Remove weights constant check
jiwaszki cfcb68b
Change order of updating the nodes, more consuming one is now later
jiwaszki e8b06b5
Fix values updating and add luci tests
jiwaszki 3bf2649
Fix codestyle
jiwaszki d386541
Rename pass
jiwaszki 4180734
Add luci tests with models
jiwaszki 761d303
Fix scalar vs multi-dim case
jiwaszki fa4733b
Separate bias and weights updating, remove checks
jiwaszki 40dcacf
[luci/pass] Introduce FuseMulWithFullyConnectedPass
jiwaszki b717234
[one-cmds] Add an option for FuseMulWithFullyConnectedPass
jiwaszki a568d25
[circle2circle] Dredd test for FuseMulWithFullyConnectedPass
jiwaszki d3246e3
[luci/pass] Value test for FuseMulWithFullyConnectedPass
jiwaszki f661561
Change constness of args, move tests and move FuseMulWithFC after Fus…
jiwaszki 85d9783
Fix codestyle
jiwaszki 51dd43c
Fix order of cmds
jiwaszki e3b354e
Remove default arguments
jiwaszki ffc36e9
Remove default args
jiwaszki 835126a
Merge remote-tracking branch 'upstream/master' into jiwaszki/fuse_mul_fc
jiwaszki 31e25ed
Refactor solution and apply comments
jiwaszki 396d733
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki d5ec1d8
Merge branch 'jiwaszki/fuse_mul_fc_one_cmds' into jiwaszki/fuse_mul_fc
jiwaszki 8b17f47
Add handling of no bias case to pass
jiwaszki 8d90e50
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki 715cdf7
Remove random newline
jiwaszki 9e22b26
Apply comments, refactor tests and add proper handling of OUTPUTEXCLUDE
jiwaszki 62a09a0
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki 1b6c71f
Resolve one-cmds duplication
jiwaszki 8977ef9
Handle rank 0 and 1
jiwaszki dbed1b9
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki 53aa943
Add new testcase
jiwaszki 0c2bb71
Add new testcase
jiwaszki e3ff517
[res/tfl_recipes] Add new Net_FullyConnected_Mul
jiwaszki d6e8b4a
Merge branch 'jiwaszki/fuse_mul_fc_new_tfl_recipes' into jiwaszki/fus…
jiwaszki b4ebd44
Merge branch 'jiwaszki/fuse_mul_fc_c2c_dredd' into jiwaszki/fuse_mul_fc
jiwaszki cddd353
Merge branch 'jiwaszki/fuse_mul_fc_luci_test' into jiwaszki/fuse_mul_fc
jiwaszki 678869c
Change name of operand from B to scale
jiwaszki af3119d
Merge branch 'jiwaszki/fuse_mul_fc_new_tfl_recipes' into jiwaszki/fus…
jiwaszki b085181
Update names from scalar to single element
jiwaszki 1aa79cc
Update tests
jiwaszki f02cb88
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki 1bb278d
Fix codestyle
jiwaszki 27dec03
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki 79a2213
Search from mul, update tests
jiwaszki 7ea759a
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki 550e798
Annotate requirement of one successor and refactor checks
jiwaszki bda96d8
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,21 +29,38 @@ namespace | |
if (not(cond)) \ | ||
return false; | ||
|
||
inline void update_values(luci::CircleConst *fused_node, luci::CircleConst *multiplication) | ||
inline void update_values(luci::CircleConst *fused_node, luci::CircleConst *multiplication, bool is_weights) | ||
{ | ||
auto node_size = fused_node->size<loco::DataType::FLOAT32>(); | ||
// Scalar: | ||
auto mul_size = multiplication->size<loco::DataType::FLOAT32>(); | ||
// Scalar multiplication: | ||
if (multiplication->rank() == 1 || | ||
multiplication->rank() == 0 && multiplication->size<loco::DataType::FLOAT32>() == 1) | ||
multiplication->rank() == 0 && mul_size == 1) | ||
{ | ||
for (uint32_t i = 0; i < node_size; i++) | ||
fused_node->at<loco::DataType::FLOAT32>(i) *= multiplication->at<loco::DataType::FLOAT32>(0); | ||
} | ||
// N-size: | ||
// N-size multiplication: | ||
else | ||
{ | ||
for (uint32_t i = 0; i < node_size; i++) | ||
fused_node->at<loco::DataType::FLOAT32>(i) *= multiplication->at<loco::DataType::FLOAT32>(i); | ||
// Go along channels, multiplication size is ensured to be compatible with channels. | ||
if (is_weights) // weights 2-D | ||
{ | ||
auto count = fused_node->dim(0).value(); | ||
auto size = fused_node->dim(fused_node->rank() - 1).value(); | ||
float val; | ||
for (uint32_t c = 0; c < count; c++) { | ||
val = multiplication->at<loco::DataType::FLOAT32>(c); | ||
for (uint32_t i = 0; i < size; i++) { | ||
fused_node->at<loco::DataType::FLOAT32>(c * size + i) *= val; | ||
} | ||
} | ||
} | ||
else // bias 1-D | ||
{ | ||
for (uint32_t i = 0; i < node_size; i++) | ||
fused_node->at<loco::DataType::FLOAT32>(i) *= multiplication->at<loco::DataType::FLOAT32>(i); | ||
} | ||
} | ||
} | ||
|
||
|
@@ -98,26 +115,15 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) | |
// Check channel-wise broadcasting: | ||
for (uint32_t i = 0; i < rank - 1; i++) | ||
RETURN_FALSE_UNLESS(multiplication->dim(i).value() == 1); | ||
// Check the last dimesion of Mul is the same with the first dimension of FullyConnected | ||
RETURN_FALSE_UNLESS(multiplication->dim(rank - 1) == weights->dim(0)); | ||
} | ||
// Scalar case: | ||
else if (multiplication->rank() == 1 || multiplication->rank() == 0) | ||
{ | ||
RETURN_FALSE_UNLESS(multiplication->size<loco::DataType::FLOAT32>() != 0); | ||
} | ||
|
||
// Update weights accordingly. | ||
RETURN_FALSE_UNLESS(weights->opcode() == luci::CircleOpcode::CIRCLECONST or | ||
weights->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE) | ||
// Create new weights to be updated with values: | ||
auto fused_weights = luci::clone(weights); | ||
RETURN_FALSE_UNLESS(fused_weights->size<loco::DataType::FLOAT32>() == | ||
weights->size<loco::DataType::FLOAT32>()); | ||
|
||
update_values(fused_weights, multiplication); | ||
|
||
fc->weights(fused_weights); | ||
|
||
// Update bias accordingly. | ||
// Only supports: | ||
// (1) constant bias | ||
// (2) no bias | ||
|
@@ -133,14 +139,18 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) | |
RETURN_FALSE_UNLESS(fused_bias->size<loco::DataType::FLOAT32>() == | ||
const_bias->size<loco::DataType::FLOAT32>()); | ||
|
||
update_values(fused_bias, multiplication); | ||
// Create new weights to be updated with values: | ||
auto fused_weights = luci::clone(weights); | ||
RETURN_FALSE_UNLESS(fused_weights->size<loco::DataType::FLOAT32>() == | ||
weights->size<loco::DataType::FLOAT32>()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems unnecessary. |
||
|
||
// Here fused_bias's shape is either [1, 1, ..., N] or [N] | ||
// where N is weights->dim(0). | ||
// The shape is normalized to [N] to become the bias of FullyConected. | ||
fused_bias->rank(1); | ||
fused_bias->dim(0) = weights->dim(0); | ||
// Update bias accordingly: | ||
update_values(fused_bias, multiplication, false); | ||
// Update weights accordingly: | ||
update_values(fused_weights, multiplication, true); | ||
|
||
// Replace weights and bias: | ||
fc->weights(fused_weights); | ||
fc->bias(fused_bias); | ||
|
||
// Set origin and copy Activation Function if exisitng: | ||
|
218 changes: 218 additions & 0 deletions
218
compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "luci/Pass/FuseMulWithFullyConnectedPass.h" | ||
#include "helpers/CreateCircleConst.h" | ||
|
||
#include <luci/IR/CircleNodes.h> | ||
#include <luci/test/TestIOGraph.h> | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#define DIM_ONE 8 | ||
#define DIM_TWO 4 | ||
#define MUL_VAL 2.0f | ||
|
||
namespace | ||
{ | ||
|
||
using namespace luci::test; | ||
|
||
/** | ||
* Graph for this test | ||
* | ||
* BEFORE | ||
* | ||
* [FC] | ||
* | | ||
* [Mul w/ Relu] | ||
* | ||
* AFTER | ||
* | ||
* [FC w/ Relu] (weights and bias updated) | ||
* | ||
*/ | ||
class FCMulGraphlet | ||
{ | ||
public: | ||
FCMulGraphlet() = default; | ||
|
||
void init(loco::Graph *g, luci::FusedActFunc fc_activation, bool is_mul_scalar) | ||
{ | ||
std::vector<float> weights_val(DIM_ONE * DIM_TWO); | ||
for (uint32_t i = 0; i < DIM_ONE * DIM_TWO; i++) | ||
weights_val.at(i) = i; | ||
|
||
_fc_f = luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE, DIM_TWO}, weights_val); | ||
|
||
std::vector<float> bias_val(DIM_ONE); | ||
for (uint32_t i = 0; i < DIM_ONE; i++) | ||
bias_val.at(i) = i; | ||
|
||
_fc_b = luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE}, bias_val); | ||
|
||
_fc = g->nodes()->create<luci::CircleFullyConnected>(); | ||
_fc->weights(_fc_f); | ||
_fc->bias(_fc_b); | ||
_fc->fusedActivationFunction(fc_activation); | ||
_fc->dtype(loco::DataType::FLOAT32); | ||
_fc->shape({1, DIM_ONE}); | ||
_fc->name("fc"); | ||
|
||
std::vector<float> mul_values; | ||
|
||
if(is_mul_scalar) { | ||
mul_values.push_back(static_cast<float>(MUL_VAL)); | ||
_mul_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, mul_values); | ||
} | ||
else { | ||
for (uint32_t i = 0; i < DIM_ONE; i++) { | ||
mul_values.push_back(static_cast<float>(i)); | ||
} | ||
_mul_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 1, 1, DIM_ONE}, mul_values); | ||
} | ||
|
||
_mul = g->nodes()->create<luci::CircleMul>(); | ||
_mul->x(_fc); | ||
_mul->y(_mul_c); | ||
_mul->fusedActivationFunction(luci::FusedActFunc::RELU); | ||
_mul->dtype(loco::DataType::FLOAT32); | ||
if(is_mul_scalar) { | ||
_mul->shape({1}); | ||
} | ||
else { | ||
_mul->shape({1, DIM_ONE}); | ||
} | ||
_mul->name("mul"); | ||
} | ||
|
||
public: | ||
luci::CircleFullyConnected *fc() { return _fc; } | ||
|
||
void to_fm_bias(void) | ||
{ | ||
assert(_fc != nullptr); | ||
|
||
auto new_fc = _fc->graph()->nodes()->create<luci::CircleFullyConnected>(); | ||
_fc->bias(new_fc); | ||
} | ||
|
||
protected: | ||
luci::CircleFullyConnected *_fc = nullptr; | ||
luci::CircleMul *_mul = nullptr; | ||
luci::CircleConst *_fc_f = nullptr; | ||
luci::CircleConst *_fc_b = nullptr; | ||
luci::CircleConst *_mul_c = nullptr; | ||
}; | ||
|
||
class FuseAddWithFCTestGraph : public TestIOGraph, public FCMulGraphlet | ||
{ | ||
public: | ||
FuseAddWithFCTestGraph() = default; | ||
|
||
void init(luci::FusedActFunc fc_activation = luci::FusedActFunc::NONE, bool is_mul_scalar = false) | ||
{ | ||
TestIOGraph::init({1, DIM_TWO}, {1, DIM_ONE}); | ||
FCMulGraphlet::init(g(), fc_activation, is_mul_scalar); | ||
|
||
_fc->input(input()); | ||
|
||
output()->from(_mul); | ||
} | ||
}; | ||
|
||
class FuseMulWithFullyConnectedPassTest : public ::testing::Test | ||
{ | ||
public: | ||
FuseAddWithFCTestGraph g; | ||
luci::FuseMulWithFullyConnectedPass pass; | ||
}; | ||
|
||
TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_not_scalar) | ||
{ | ||
g.init(luci::FusedActFunc::NONE, false); | ||
|
||
EXPECT_EQ(true, pass.run(g.g())); | ||
|
||
auto fc = dynamic_cast<luci::CircleFullyConnected *>(g.output()->from()); | ||
EXPECT_NE(nullptr, fc); | ||
|
||
auto weights = loco::must_cast<luci::CircleConst *>(g.fc()->weights()); | ||
auto weights_n = weights->dim(0).value(); | ||
auto weights_m = weights->dim(1).value(); | ||
uint32_t offset = 0; | ||
for (uint32_t i = 0; i < weights_n; i++) | ||
{ | ||
for (uint32_t j = 0; j < weights_m; j++) | ||
{ | ||
offset = i * weights_m + j; | ||
EXPECT_EQ(i * offset, weights->at<loco::DataType::FLOAT32>(offset)); | ||
} | ||
} | ||
|
||
auto bias = loco::must_cast<luci::CircleConst *>(g.fc()->bias()); | ||
for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++) | ||
{ | ||
EXPECT_EQ(i * i, bias->at<loco::DataType::FLOAT32>(i)); | ||
} | ||
} | ||
|
||
TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_is_scalar) | ||
{ | ||
g.init(luci::FusedActFunc::NONE, true); | ||
|
||
EXPECT_EQ(true, pass.run(g.g())); | ||
|
||
auto fc = dynamic_cast<luci::CircleFullyConnected *>(g.output()->from()); | ||
EXPECT_NE(nullptr, fc); | ||
|
||
auto weights = loco::must_cast<luci::CircleConst *>(g.fc()->weights()); | ||
auto weights_n = weights->dim(0).value(); | ||
auto weights_m = weights->dim(1).value(); | ||
uint32_t offset = 0; | ||
for (uint32_t i = 0; i < weights_n; i++) | ||
{ | ||
for (uint32_t j = 0; j < weights_m; j++) | ||
{ | ||
offset = i * weights_m + j; | ||
EXPECT_EQ(MUL_VAL * offset, weights->at<loco::DataType::FLOAT32>(offset)); | ||
} | ||
} | ||
|
||
auto bias = loco::must_cast<luci::CircleConst *>(g.fc()->bias()); | ||
for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++) | ||
{ | ||
EXPECT_EQ(MUL_VAL * i, bias->at<loco::DataType::FLOAT32>(i)); | ||
} | ||
} | ||
|
||
TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG) | ||
{ | ||
g.init(); | ||
|
||
// Bias cannot be fused as it's passed as feature map. | ||
g.to_fm_bias(); | ||
|
||
EXPECT_EQ(false, pass.run(g.g())); | ||
} | ||
|
||
TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_activation_NEG) | ||
{ | ||
g.init(luci::FusedActFunc::RELU); | ||
|
||
EXPECT_EQ(false, pass.run(g.g())); | ||
} | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to follow other options
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it! Done in scope of d386541