-
Notifications
You must be signed in to change notification settings - Fork 158
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
on-going draft to support new tflite/circle schema. Signed-off-by: SaeHie Park <[email protected]>
- Loading branch information
1 parent
547bd7d
commit 937ef97
Showing
11 changed files
with
590 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
37 changes: 37 additions & 0 deletions
37
compiler/luci/pass/include/luci/Pass/XpSepActFromTransposeConvPass.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* | ||
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef __LUCI_XP_SEP_ACT_FROM_TRANSPOSE_CONV_PASS_H__ | ||
#define __LUCI_XP_SEP_ACT_FROM_TRANSPOSE_CONV_PASS_H__ | ||
|
||
#include <logo/Pass.h> | ||
|
||
namespace luci | ||
{ | ||
|
||
/** | ||
* @brief Experimental Class to separate activation functions from TransposeConv | ||
*/ | ||
struct XpSepActFromTransposeConvPass final : public logo::Pass | ||
{ | ||
const char *name(void) const final { return "luci::XpSepActFromTransposeConvPass"; } | ||
|
||
bool run(loco::Graph *g) final; | ||
}; | ||
|
||
} // namespace luci | ||
|
||
#endif // __LUCI_XP_SEP_ACT_FROM_TRANSPOSE_CONV_PASS_H__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
122 changes: 122 additions & 0 deletions
122
compiler/luci/pass/src/XpSepActFromTransposeConvPass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
/* | ||
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "luci/Pass/XpSepActFromTransposeConvPass.h" | ||
|
||
#include <luci/IR/CircleNodes.h> | ||
#include <luci/IR/CircleNodeMixins.h> | ||
#include <luci/Profile/CircleNodeOrigin.h> | ||
|
||
namespace luci | ||
{ | ||
|
||
/** | ||
* XpSepActFromTransposeConvPass | ||
* - Experimental Separate Activation From TransposeConv | ||
* - This pass exist temporary to separate activation function from | ||
* - TransposeConv to support backends that don't support this. | ||
* - This pass will be removed when all backends support fused activation. | ||
* | ||
* BEFORE | ||
* [Node] | ||
* | | ||
* [TransposeConv] (w/ Act) | ||
* | | ||
* [Node] | ||
* | ||
* AFTER | ||
* | ||
* [Node] | ||
* | | ||
* [TransposeConv] | ||
* | | ||
* [ReLU/ReLU6/...] | ||
* | | ||
* [Node] | ||
* | ||
*/ | ||
|
||
namespace | ||
{ | ||
|
||
bool separate_activation_fuction(luci::CircleTransposeConv *trconv) | ||
{ | ||
auto fused_act = trconv->fusedActivationFunction(); | ||
if (fused_act == luci::FusedActFunc::NONE) | ||
return false; | ||
if (fused_act == luci::FusedActFunc::UNDEFINED) | ||
throw std::runtime_error("XpSepActFromTransposeConvPass Activation is undefined"); | ||
|
||
// NOTE features() is call after replace().with(); | ||
// calling loco::replace(trconv).with(actnode) will also update actnode | ||
// itself which will make totally wrong result with actnode input being | ||
// itself. this happends as TransposeConv is re-used, not replaced with | ||
// a new one. | ||
|
||
auto name = trconv->name(); | ||
luci::CircleNode *actnode = nullptr; | ||
switch (fused_act) | ||
{ | ||
case luci::FusedActFunc::RELU: | ||
{ | ||
auto af = trconv->graph()->nodes()->create<luci::CircleRelu>(); | ||
loco::replace(trconv).with(af); | ||
af->features(trconv); | ||
af->name(name + "/Relu"); | ||
actnode = af; | ||
} | ||
break; | ||
case luci::FusedActFunc::RELU6: | ||
{ | ||
auto af = trconv->graph()->nodes()->create<luci::CircleRelu6>(); | ||
loco::replace(trconv).with(af); | ||
af->features(trconv); | ||
af->name(name + "/Relu6"); | ||
actnode = af; | ||
} | ||
break; | ||
// TODO support more | ||
default: | ||
return false; | ||
} | ||
assert(actnode != nullptr); | ||
actnode->dtype(trconv->dtype()); | ||
luci::add_origin(actnode, luci::get_origin(trconv)); | ||
|
||
trconv->fusedActivationFunction(luci::FusedActFunc::NONE); | ||
|
||
return true; | ||
} | ||
|
||
} // namespace | ||
|
||
bool XpSepActFromTransposeConvPass::run(loco::Graph *g) | ||
{ | ||
bool changed = false; | ||
for (auto node : loco::active_nodes(loco::output_nodes(g))) | ||
{ | ||
auto trconv = dynamic_cast<luci::CircleTransposeConv *>(node); | ||
if (trconv != nullptr) | ||
{ | ||
if (separate_activation_fuction(trconv)) | ||
changed = true; | ||
} | ||
} | ||
|
||
return changed; | ||
} | ||
|
||
} // namespace luci |
150 changes: 150 additions & 0 deletions
150
compiler/luci/pass/src/XpSepActFromTransposeConvPass.test.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
/* | ||
* Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "luci/Pass/XpSepActFromTransposeConvPass.h" | ||
|
||
#include <luci/IR/CircleNodes.h> | ||
|
||
#include <luci/test/TestIOGraph.h> | ||
#include "test/TestFirstNode.h" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
namespace | ||
{ | ||
|
||
using namespace luci::test; | ||
|
||
class TrConvGraphlet | ||
{ | ||
public: | ||
TrConvGraphlet() = default; | ||
|
||
public: | ||
void init(loco::Graph *g) | ||
{ | ||
ShapeU32 wshape = {1, 4, 4, 3}; | ||
const uint32_t elements_num = num_elements(wshape); | ||
|
||
// trconv inputSizes | ||
auto wshape_size = static_cast<uint32_t>(wshape.size()); | ||
_inpsize = g->nodes()->create<luci::CircleConst>(); | ||
_inpsize->dtype(loco::DataType::S32); | ||
_inpsize->shape({wshape_size}); | ||
_inpsize->size<loco::DataType::S32>(wshape_size); | ||
auto wsp = wshape.begin(); | ||
for (uint32_t idx = 0; idx < 4; idx++) | ||
{ | ||
_inpsize->at<loco::DataType::S32>(idx) = int32_t(*wsp++); | ||
} | ||
_inpsize->name("inpsize"); | ||
|
||
// trconv filter | ||
_filter = g->nodes()->create<luci::CircleConst>(); | ||
_filter->dtype(loco::DataType::FLOAT32); | ||
_filter->shape(wshape); | ||
_filter->size<loco::DataType::FLOAT32>(elements_num); | ||
for (uint32_t idx = 0; idx < elements_num; idx++) | ||
{ | ||
_filter->at<loco::DataType::FLOAT32>(idx) = float(idx); | ||
} | ||
_filter->name("filter"); | ||
|
||
// trconv | ||
_tc = g->nodes()->create<luci::CircleTransposeConv>(); | ||
_tc->dtype(loco::DataType::FLOAT32); | ||
_tc->name("trconv"); | ||
} | ||
|
||
protected: | ||
luci::CircleTransposeConv *_tc = nullptr; | ||
luci::CircleConst *_filter = nullptr; | ||
luci::CircleConst *_inpsize = nullptr; | ||
}; | ||
|
||
class TrConvGraph : public TestIGraphlet, public TestOGraphlet, public TrConvGraphlet | ||
{ | ||
public: | ||
TrConvGraph() = default; | ||
|
||
void init(const ShapeU32 shape) | ||
{ | ||
TestIGraphlet::init(g(), shape); | ||
TestOGraphlet::init(g(), shape); | ||
TrConvGraphlet::init(g()); | ||
|
||
// connect graph | ||
_tc->inputSizes(_inpsize); | ||
_tc->filter(_filter); | ||
_tc->outBackprop(input()); | ||
|
||
output()->from(_tc); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
TEST(XpSepActFromTransposeConvPassTest, name) | ||
{ | ||
luci::XpSepActFromTransposeConvPass pass; | ||
auto const name = pass.name(); | ||
ASSERT_NE(nullptr, name); | ||
} | ||
|
||
TEST(XpSepActFromTransposeConvPassTest, simple_normal) | ||
{ | ||
TrConvGraph g; | ||
|
||
g.init({1, 4, 4, 3}); | ||
|
||
auto tc_node = luci::test::first_node<luci::CircleTransposeConv>(g.g()); | ||
ASSERT_NE(tc_node, nullptr); | ||
tc_node->fusedActivationFunction(luci::FusedActFunc::RELU); | ||
|
||
luci::XpSepActFromTransposeConvPass pass; | ||
EXPECT_EQ(pass.run(g.g()), true); | ||
|
||
auto la_node = dynamic_cast<luci::CircleRelu *>(g.output()->from()); | ||
ASSERT_NE(la_node, nullptr); | ||
} | ||
|
||
TEST(XpSepActFromTransposeConvPassTest, none_act_NEG) | ||
{ | ||
TrConvGraph g; | ||
|
||
g.init({1, 4, 4, 3}); | ||
|
||
auto tc_node = luci::test::first_node<luci::CircleTransposeConv>(g.g()); | ||
ASSERT_NE(tc_node, nullptr); | ||
tc_node->fusedActivationFunction(luci::FusedActFunc::NONE); | ||
|
||
luci::XpSepActFromTransposeConvPass pass; | ||
EXPECT_NE(pass.run(g.g()), true); | ||
} | ||
|
||
TEST(XpSepActFromTransposeConvPassTest, invalid_act_NEG) | ||
{ | ||
TrConvGraph g; | ||
|
||
g.init({1, 4, 4, 3}); | ||
|
||
auto tc_node = luci::test::first_node<luci::CircleTransposeConv>(g.g()); | ||
ASSERT_NE(tc_node, nullptr); | ||
// leave activation as undefined | ||
|
||
luci::XpSepActFromTransposeConvPass pass; | ||
EXPECT_ANY_THROW(pass.run(g.g())); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.