Skip to content

Commit

Permalink
DRAFT: Support MUL dynamic shape inference
Browse files Browse the repository at this point in the history
On going draft to support dynamic shape inference for MUL.

ONE-DCO-1.0-Signed-off-by: sunki <[email protected]>
  • Loading branch information
qsunki committed Aug 27, 2024
1 parent 96b5c2e commit 94a3b1f
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape>
// loco::TensorShape visit(const luci::CircleMean *node) final;
// loco::TensorShape visit(const luci::CircleMinimum *node) final;
// loco::TensorShape visit(const luci::CircleMirrorPad *node) final;
// loco::TensorShape visit(const luci::CircleMul *node) final;
loco::TensorShape visit(const luci::CircleMul *node) final;
// loco::TensorShape visit(const luci::CircleNeg *node) final;
// loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4 *node) final;
// loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5 *node) final;
Expand Down
2 changes: 0 additions & 2 deletions compiler/luci/service/src/CircleShapeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2205,8 +2205,6 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeS

loco::NodeShape visit(const luci::CircleMirrorPad *node) final { return infer_mirror_pad(node); }

loco::NodeShape visit(const luci::CircleMul *node) final { return broadcast_xy(node); }

loco::NodeShape visit(const luci::CircleNeg *node) final { return use_x(node); }

loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4 *node) final
Expand Down
15 changes: 15 additions & 0 deletions compiler/luci/service/src/Nodes/CircleMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
* limitations under the License.
*/

#include <luci/Service/CircleShapeInference.h>

#include "CircleCloneNode.h"

#include "CircleShapeInferenceHelper.h"

namespace luci
{

Expand All @@ -30,4 +34,15 @@ luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleMul *node)
return cloned;
}

loco::TensorShape sinf::Algorithm::visit(const luci::CircleMul *node)
{
const auto x = loco::must_cast<luci::CircleNode *>(node->x());
const auto y = loco::must_cast<luci::CircleNode *>(node->y());

const auto x_shape = sinf::circle_shape(x);
const auto y_shape = sinf::circle_shape(y);

return broadcast_shape(x_shape, y_shape);
}

} // namespace luci
261 changes: 261 additions & 0 deletions compiler/luci/service/src/Nodes/CircleMul.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include "luci/Service/CircleNodeClone.h"

#include <luci/Service/CircleShapeInference.h>

#include <gtest/gtest.h>

TEST(CloneNodeTest, clone_Mul)
Expand Down Expand Up @@ -44,3 +46,262 @@ TEST(CloneNodeTest, clone_Mul_NEG)
auto cloned = luci::clone_node(node_mul, gc.get());
ASSERT_EQ(nullptr, cloned);
}

TEST(ShapeRuleTest, mul_known_dim)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleMul mul;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({1, 1, 1});
input_2.shape_status(luci::ShapeStatus::VALID);

mul.x(&input_1);
mul.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&mul, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_TRUE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(4, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(1, shape.dim(3).value());
}

TEST(ShapeRuleTest, mul_dynamic_shape_non_1)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleMul mul;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({1, 1, 1});
input_2.shape_status(luci::ShapeStatus::VALID);
input_2.dim(0).unset();

mul.x(&input_1);
mul.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&mul, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_TRUE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(4, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(1, shape.dim(3).value());
}

TEST(ShapeRuleTest, mul_dynamic_shape_1)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleMul mul;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({1, 1, 1});
input_2.shape_status(luci::ShapeStatus::VALID);
input_2.dim(2).unset();

mul.x(&input_1);
mul.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&mul, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_FALSE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(4, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(0, shape.dim(3).value());
}

TEST(ShapeRuleTest, mul_dynamic_shape_both)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleMul mul;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);
input_1.dim(3).unset();

input_2.shape({1, 1, 1});
input_2.shape_status(luci::ShapeStatus::VALID);
input_2.dim(2).unset();

mul.x(&input_1);
mul.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&mul, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_FALSE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(4, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(0, shape.dim(3).value());
}

TEST(ShapeRuleTest, mul_scalar)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleMul mul;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({});
input_2.shape_status(luci::ShapeStatus::VALID);

mul.x(&input_1);
mul.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&mul, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_TRUE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(4, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(1, shape.dim(3).value());
}

TEST(ShapeRuleTest, mul_not_broadcastable_NEG)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleMul mul;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({1, 2, 1});
input_2.shape_status(luci::ShapeStatus::VALID);

mul.x(&input_1);
mul.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&mul, shape));
}

TEST(ShapeRuleTest, mul_not_broadcastable_2_NEG)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleMul mul;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({2, 1, 1});
input_2.shape_status(luci::ShapeStatus::VALID);

mul.x(&input_1);
mul.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&mul, shape));
}

TEST(ShapeRuleTest, mul_not_broadcastable_3_NEG)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleMul mul;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({2, 3, 1});
input_2.shape_status(luci::ShapeStatus::VALID);

mul.x(&input_1);
mul.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&mul, shape));
}

TEST(ShapeRuleTest, mul_not_broadcastable_4_NEG)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleMul mul;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({2, 3, 2});
input_2.shape_status(luci::ShapeStatus::VALID);

mul.x(&input_1);
mul.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&mul, shape));
}

TEST(ShapeRuleTest, mul_not_broadcastable_5_NEG)
{
luci::CircleInput input_1;
luci::CircleInput input_2;
luci::CircleMul mul;

input_1.shape({1, 4, 3, 1});
input_1.shape_status(luci::ShapeStatus::VALID);

input_2.shape({3, 2, 3, 2});
input_2.shape_status(luci::ShapeStatus::VALID);

mul.x(&input_1);
mul.y(&input_2);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&mul, shape));
}

0 comments on commit 94a3b1f

Please sign in to comment.