Skip to content

Commit

Permalink
[luci] Infer dynamic shape for concat (#13698)
Browse files Browse the repository at this point in the history
This infers dynamic shape for concat Op.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening authored Aug 21, 2024
1 parent 7d46167 commit b22860f
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 6 deletions.
33 changes: 27 additions & 6 deletions compiler/luci/service/src/CircleShapeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,14 +568,35 @@ loco::NodeShape infer_concatenation(const luci::CircleConcatenation *node)
{
if (j == static_cast<uint32_t>(axis))
{
// If dimension is unknown, value() will return 0.
// This is wrong but until new inference algorithm is implemented,
// this code will not be modified to keep compatibility.
output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value();
if (output_shape.dim(j).known() and input_shape.dim(j).known())
{
output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value();
}
else
{
// If any of inputs is unknown, just mark it as unknown.
output_shape.dim(j).unset();
}
}
else
assert(!output_shape.dim(j).known() || !input_shape.dim(j).known() ||
output_shape.dim(j) == input_shape.dim(j));
{
if (output_shape.dim(j).known() and input_shape.dim(j).known())
{
if (output_shape.dim(j).value() != input_shape.dim(j).value())
{
INTERNAL_EXN_V("Input has incompatible shape.", node->name());
}
}
else
{
if (input_shape.dim(j).known())
{
assert(not output_shape.dim(j).known()); // FIX_ME_UNLESS
output_shape.dim(j) = input_shape.dim(j);
}
// For unknown input_shape, leave output_shape as-is
}
}
}
}

Expand Down
111 changes: 111 additions & 0 deletions compiler/luci/service/src/Nodes/CircleConcatenation.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include "luci/Service/CircleNodeClone.h"

#include <luci/Service/CircleShapeInference.h>

#include <gtest/gtest.h>

TEST(CloneNodeTest, clone_Concatenation)
Expand Down Expand Up @@ -47,3 +49,112 @@ TEST(CloneNodeTest, clone_Concatenation_NEG)
auto cloned = luci::clone_node(node_concat, gc.get());
ASSERT_EQ(nullptr, cloned);
}

TEST(ShapeRuleTest, concat_dynamic_shape_axis)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleConcatenation concat(2);

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({1, 4, 3, 1});
input_2.shape_status(luci::ShapeStatus::VALID);
input_2.dim(2).unset();

concat.values(0, &input_1);
concat.values(1, &input_2);
concat.axis(2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&concat, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_FALSE(shape.dim(2).known());
ASSERT_TRUE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(4, shape.dim(1).value());
ASSERT_EQ(0, shape.dim(2).value());
ASSERT_EQ(1, shape.dim(3).value());
}

TEST(ShapeRuleTest, concat_dynamic_shape_non_axis)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleConcatenation concat(2);

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({1, 4, 3, 1});
input_2.shape_status(luci::ShapeStatus::VALID);
input_2.dim(2).unset();

concat.values(0, &input_1);
concat.values(1, &input_2);
concat.axis(1);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&concat, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_TRUE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(8, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(1, shape.dim(3).value());
}

TEST(ShapeRuleTest, concat_wrong_shape_NEG)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleConcatenation concat(2);

input_1.shape({1, 4, 4, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({1, 4, 3, 1});
input_2.shape_status(luci::ShapeStatus::VALID);

concat.values(0, &input_1);
concat.values(1, &input_2);
concat.axis(1);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

EXPECT_ANY_THROW(shape_inf_rule.infer(&concat, shape));
}

TEST(ShapeRuleTest, concat_rank_mismatch_NEG)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleConcatenation concat(2);

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({1, 4, 1});
input_2.shape_status(luci::ShapeStatus::VALID);
input_2.dim(2).unset();

concat.values(0, &input_1);
concat.values(1, &input_2);
concat.axis(2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

EXPECT_ANY_THROW(shape_inf_rule.infer(&concat, shape));
}

0 comments on commit b22860f

Please sign in to comment.