-
Notifications
You must be signed in to change notification settings - Fork 158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[luci/pass] Refactor FuseAddWithFullyConnectedPass #13846
base: master
Are you sure you want to change the base?
[luci/pass] Refactor FuseAddWithFullyConnectedPass #13846
Conversation
This commit changes the order of searching for the pattern. ONE-DCO-1.0-Signed-off-by: Jan Iwaszkiewicz <[email protected]>
For: #13685 |
|
||
auto fused_bias = luci::clone(addition); | ||
|
||
// Add existing bias values | ||
if (auto const_bias = dynamic_cast<luci::CircleConst *>(fc->bias())) | ||
{ | ||
assert(const_bias->dtype() == loco::DataType::FLOAT32); | ||
RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you check this inside common_pass_checks
? luci::clone
adds a node into the graph (graph is modified), so all checks had to be done before luci::clone
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, this check is specific to float32.. It would be better to use a pattern class (PTAL how FuseGeluPass
finds different patterns).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jinevening could this refactor be done in separate PR? In my opinion, it would be better to push it as-is and create another PR that will introduce pattern search like FuseGeluPass
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
if (not(cond)) \ | ||
return false; | ||
|
||
bool fc_with_add_pattern_check(const loco::DataType dtype, luci::CircleAdd **add, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Q) why use pointers of pointer? add
is not updated here.
It would be better not to use double pointer if not necessary
and also use reference if need update in this method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@seanshpark as for add
I fully understand the concern. This is changed.
As for double pointers, it was all required by luci::fill
to get proper handling (i.e. do not loose pointer locally in the function, which resulted in segfaults). Now I propose use of struct to keep track of nodes in the pattern. I see it as proper middle-ground solution, that also organize the required nodes in one place. What do you think?
|
||
RETURN_FALSE_UNLESS(luci::fill(fc, addition).with_commutative_args_of(*add)); | ||
bool fc_with_add_pattern_check(const loco::DataType dtype, const luci::CircleAdd &add, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Q) Why are you using const luci::CircleAdd &
?
|
||
RETURN_FALSE_UNLESS(luci::fill(fc, addition).with_commutative_args_of(*add)); | ||
bool fc_with_add_pattern_check(const loco::DataType dtype, const luci::CircleAdd &add, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of passing loco::DataType
,
template <typename DT>
bool fc_with_add_pattern_check(const luci::CircleAdd &add,
?
// TODO Support scalar addition | ||
if (rank == 0) | ||
return false; | ||
PatternNodes nodes; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems you want to find members of this nodes
.
there are similar styles in other Pass that you can refer.
plz check compiler/luci/pass/src/FuseRmsNormPass.cpp
that was recently added.
This commit changes the order of searching for the pattern.
ONE-DCO-1.0-Signed-off-by: Jan Iwaszkiewicz [email protected]