Skip to content

Commit

Permalink
DRAFT: Support ADD dynamic shape inference
Browse files Browse the repository at this point in the history
On goging draft to support dynamic shape inference for ADD.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening committed Aug 26, 2024
1 parent 6e70506 commit ab2c557
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape>
}

// loco::TensorShape visit(const luci::CircleAbs *node) final;
// loco::TensorShape visit(const luci::CircleAdd *node) final;
loco::TensorShape visit(const luci::CircleAdd *node) final;
// loco::TensorShape visit(const luci::CircleAddN *node) final;
// loco::TensorShape visit(const luci::CircleArgMax *node) final;
// loco::TensorShape visit(const luci::CircleArgMin *node) final;
Expand Down
2 changes: 0 additions & 2 deletions compiler/luci/service/src/CircleShapeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2019,8 +2019,6 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeS
public:
loco::NodeShape visit(const luci::CircleAbs *node) final { return use_x(node); }

loco::NodeShape visit(const luci::CircleAdd *node) final { return broadcast_xy(node); }

loco::NodeShape visit(const luci::CircleAddN *node) final { return infer_add_n(node); }

loco::NodeShape visit(const luci::CircleArgMax *node) final { return infer_arg_maxmin(node); }
Expand Down
15 changes: 15 additions & 0 deletions compiler/luci/service/src/Nodes/CircleAdd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
* limitations under the License.
*/

#include <luci/Service/CircleShapeInference.h>

#include "CircleShapeInferenceHelper.h"

#include "CircleCloneNode.h"

namespace luci
Expand All @@ -30,4 +34,15 @@ luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleAdd *node)
return cloned;
}

loco::TensorShape sinf::Algorithm::visit(const luci::CircleAdd *node)
{
const auto x = loco::must_cast<luci::CircleNode *>(node->x());
const auto y = loco::must_cast<luci::CircleNode *>(node->y());

const auto x_shape = sinf::circle_shape(x);
const auto y_shape = sinf::circle_shape(y);

return broadcast_shape(x_shape, y_shape);
}

} // namespace luci
259 changes: 259 additions & 0 deletions compiler/luci/service/src/Nodes/CircleAdd.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,262 @@ TEST(CloneNodeTest, clone_Add_NEG)
auto cloned = luci::clone_node(node_add, gc.get());
ASSERT_EQ(nullptr, cloned);
}

TEST(ShapeRuleTest, add_known_dim)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleAdd add;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({1, 1, 1});
input_2.shape_status(luci::ShapeStatus::VALID);

add.x(&input_1);
add.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&add, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_TRUE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(4, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(1, shape.dim(3).value());
}

TEST(ShapeRuleTest, add_dynamic_shape_non_1)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleAdd add;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({1, 1, 1});
input_2.shape_status(luci::ShapeStatus::VALID);
input_2.dim(0).unset();

add.x(&input_1);
add.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&add, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_TRUE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(4, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(1, shape.dim(3).value());
}

TEST(ShapeRuleTest, add_dynamic_shape_1)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleAdd add;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({1, 1, 1});
input_2.shape_status(luci::ShapeStatus::VALID);
input_2.dim(2).unset();

add.x(&input_1);
add.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&add, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_FALSE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(4, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(0, shape.dim(3).value());
}

TEST(ShapeRuleTest, add_dynamic_shape_both)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleAdd add;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);
input_1.dim(3).unset();

input_2.shape({1, 1, 1});
input_2.shape_status(luci::ShapeStatus::VALID);
input_2.dim(2).unset();

add.x(&input_1);
add.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&add, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_FALSE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(4, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(0, shape.dim(3).value());
}

TEST(ShapeRuleTest, add_scalar)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleAdd add;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({});
input_2.shape_status(luci::ShapeStatus::VALID);

add.x(&input_1);
add.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&add, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_TRUE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(4, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(1, shape.dim(3).value());
}

TEST(ShapeRuleTest, add_not_broadcastable_NEG)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleAdd add;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({1, 2, 1});
input_2.shape_status(luci::ShapeStatus::VALID);

add.x(&input_1);
add.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&add, shape));
}

TEST(ShapeRuleTest, add_not_broadcastable_2_NEG)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleAdd add;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({2, 1, 1});
input_2.shape_status(luci::ShapeStatus::VALID);

add.x(&input_1);
add.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&add, shape));
}

TEST(ShapeRuleTest, add_not_broadcastable_3_NEG)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleAdd add;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({2, 3, 1});
input_2.shape_status(luci::ShapeStatus::VALID);

add.x(&input_1);
add.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&add, shape));
}

TEST(ShapeRuleTest, add_not_broadcastable_4_NEG)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleAdd add;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({2, 3, 2});
input_2.shape_status(luci::ShapeStatus::VALID);

add.x(&input_1);
add.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&add, shape));
}

TEST(ShapeRuleTest, add_not_broadcastable_5_NEG)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleAdd add;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({3, 2, 3, 2});
input_2.shape_status(luci::ShapeStatus::VALID);

add.x(&input_1);
add.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&add, shape));
}
4 changes: 0 additions & 4 deletions compiler/luci/tests/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ addread(Abs_000)
addread(Add_000)
addread(Add_001)
addread(Add_U8_000)
addread(Add_STR_000)
addread(Add_STR_001)
addread(AddN_000)
addread(ArgMax_000)
addread(ArgMax_001)
Expand Down Expand Up @@ -235,8 +233,6 @@ addwrite(Abs_000)
addwrite(Add_000)
addwrite(Add_001)
addwrite(Add_U8_000)
addwrite(Add_STR_000)
addwrite(Add_STR_001)
addwrite(AddN_000)
addwrite(ArgMax_000)
addwrite(ArgMax_001)
Expand Down
33 changes: 0 additions & 33 deletions res/TensorFlowLiteRecipes/Add_STR_000/test.recipe

This file was deleted.

34 changes: 0 additions & 34 deletions res/TensorFlowLiteRecipes/Add_STR_001/test.recipe

This file was deleted.

Loading

0 comments on commit ab2c557

Please sign in to comment.