Skip to content

Commit

Permalink
WIP: [luci/service] Support BMM dynamic shape inference
Browse files Browse the repository at this point in the history
This PR supports dynamic shape inference for the BatchMatMul Op.

ONE-DCO-1.0-Signed-off-by: SeungHui Youn <[email protected]>
  • Loading branch information
zetwhite committed Aug 26, 2024
1 parent 0d4cf4c commit f6306dc
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape>
// loco::TensorShape visit(const luci::CircleArgMax *node) final;
// loco::TensorShape visit(const luci::CircleArgMin *node) final;
// loco::TensorShape visit(const luci::CircleAveragePool2D *node) final;
// loco::TensorShape visit(const luci::CircleBatchMatMul *node) final;
loco::TensorShape visit(const luci::CircleBatchMatMul *node) final;
// loco::TensorShape visit(const luci::CircleBatchToSpaceND *node) final;
// loco::TensorShape visit(const luci::CircleCast *node) final;
// loco::TensorShape visit(const luci::CircleCeil *node) final;
Expand Down
67 changes: 67 additions & 0 deletions compiler/luci/service/src/Nodes/CircleBatchMatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,27 @@

#include "CircleCloneNode.h"

#include "CircleShapeInferenceHelper.h"

namespace
{

/**
 * @brief Return a copy of the given shape without its last two dimensions
 *
 * The leading dimensions are copied in order; the two trailing dimensions are
 * dropped. The input must have rank >= 2 (checked by assert only).
 *
 * @param original_shape shape to take the leading dimensions from
 * @return shape of rank (original rank - 2)
 */
loco::TensorShape remove_last_two(const loco::TensorShape &original_shape)
{
  assert(original_shape.rank() >= 2); // FIX CALLER UNLESS

  // Compute the resulting rank once; use an unsigned index to match rank()
  // and avoid a signed/unsigned comparison in the loop condition.
  const uint32_t leading_rank = original_shape.rank() - 2;

  loco::TensorShape ret;
  ret.rank(leading_rank);

  for (uint32_t i = 0; i < leading_rank; ++i)
  {
    ret.dim(i) = original_shape.dim(i);
  }
  return ret;
}

} // namespace

namespace luci
{

Expand All @@ -30,4 +51,50 @@ luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleBatchMatMul *no
return cloned;
}

// BatchMatMulV2 supports broadcasting in the batch dimensions(BatchMatMul doesn't)
// TODO Distinguish BatchMatMul and BatchMatMulV2
/**
 * @brief Infer the output shape of a BatchMatMul node
 *
 * The leading (batch) dimensions of the two inputs are combined with
 * sinf::broadcast_shape, and the last two output dimensions are taken from the
 * row dimension of x and the column dimension of y, honoring adj_x / adj_y.
 *
 * @param node BatchMatMul node whose x and y inputs have rank >= 2
 * @return inferred output shape with rank max(rank(x), rank(y))
 * @note Raises via INTERNAL_EXN when both contraction dimensions are known
 *       and differ.
 */
loco::TensorShape sinf::Algorithm::visit(const luci::CircleBatchMatMul *node)
{
  // Shapes must come from the two operands, not from the BatchMatMul node itself
  const auto x = loco::must_cast<luci::CircleNode *>(node->x());
  const auto y = loco::must_cast<luci::CircleNode *>(node->y());

  const auto x_shape = sinf::circle_shape(x);
  const auto y_shape = sinf::circle_shape(y);

  const uint32_t x_rank = x_shape.rank();
  const uint32_t y_rank = y_shape.rank();
  assert(x_rank >= 2 && y_rank >= 2);

  const uint32_t max_rank = x_rank > y_rank ? x_rank : y_rank;
  loco::TensorShape output_shape;
  output_shape.rank(max_rank);

  // broadcast in the batch dimensions
  if (x_rank > 2 || y_rank > 2)
  {
    const auto x_batch_dims = remove_last_two(x_shape);
    const auto y_batch_dims = remove_last_two(y_shape);

    const auto o_batch_dims = sinf::broadcast_shape(x_batch_dims, y_batch_dims);

    const uint32_t o_batch_rank = o_batch_dims.rank();
    for (uint32_t i = 0; i < o_batch_rank; ++i)
    {
      output_shape.dim(i) = o_batch_dims.dim(i);
    }
  }

  // adj_x / adj_y swap the roles of the last two dimensions of each operand
  const bool adj_x = node->adj_x();
  const bool adj_y = node->adj_y();

  const loco::Dimension x_lhs = adj_x ? x_shape.dim(x_rank - 1) : x_shape.dim(x_rank - 2);
  const loco::Dimension x_rhs = adj_x ? x_shape.dim(x_rank - 2) : x_shape.dim(x_rank - 1);
  const loco::Dimension y_lhs = adj_y ? y_shape.dim(y_rank - 1) : y_shape.dim(y_rank - 2);
  const loco::Dimension y_rhs = adj_y ? y_shape.dim(y_rank - 2) : y_shape.dim(y_rank - 1);

  // TODO : Add logic for dynamic x_{lhs, rhs} and y_{lhs, rhs}
  // Only reject a mismatch when both contraction dimensions are statically known
  if (x_rhs.known() && y_lhs.known() && not(x_rhs == y_lhs))
    INTERNAL_EXN("x_rhs and y_lhs should be same");

  output_shape.dim(max_rank - 2) = x_lhs;
  output_shape.dim(max_rank - 1) = y_rhs;

  // The declared return type is loco::TensorShape, so return the shape directly
  return output_shape;
}

} // namespace luci
65 changes: 65 additions & 0 deletions compiler/luci/service/src/Nodes/CircleBatchMatMul.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include "luci/Service/CircleNodeClone.h"

#include "luci/Service/CircleTypeInference.h"

#include <gtest/gtest.h>

TEST(CloneNodeTest, clone_BatchMatMul)
Expand All @@ -35,3 +37,66 @@ TEST(CloneNodeTest, clone_BatchMatMul)
ASSERT_EQ(node_bmm->adj_x(), cloned_bmm->adj_x());
ASSERT_EQ(node_bmm->adj_y(), cloned_bmm->adj_y());
}

// BatchMatMul with fully known shapes: the batch dimension of y (1) is
// broadcast against the batch dimension of x (2).
TEST(ShapeRuleTest, bmm_broadcast_known_dim)
{
  luci::CircleInput input_x;
  luci::CircleInput input_y;
  luci::CircleBatchMatMul bmm;

  input_x.shape({2, 4, 3});
  input_x.shape_status(luci::ShapeStatus::VALID);

  input_y.shape({1, 3, 5});
  input_y.shape_status(luci::ShapeStatus::VALID);

  bmm.x(&input_x);
  bmm.y(&input_y);

  loco::TensorShape shape;
  luci::sinf::Rule shape_inf_rule;

  ASSERT_TRUE(shape_inf_rule.infer(&bmm, shape));

  // [2, 4, 3] x [1, 3, 5] -> expected shape is [2, 4, 5]
  ASSERT_EQ(3, shape.rank());
  ASSERT_TRUE(shape.dim(0).known());
  ASSERT_TRUE(shape.dim(1).known());
  ASSERT_TRUE(shape.dim(2).known());
  ASSERT_EQ(2, shape.dim(0).value());
  ASSERT_EQ(4, shape.dim(1).value());
  ASSERT_EQ(5, shape.dim(2).value());
}

// BatchMatMul where x has an unknown batch dimension and a lower rank than y.
TEST(ShapeRuleTest, bmm_broadcast_unknown_batch_dim)
{
  luci::CircleInput input_x;
  luci::CircleInput input_y;
  luci::CircleBatchMatMul bmm;

  // The leading 0 is only a placeholder; the dimension is marked unknown below
  input_x.shape({0, 4, 3});
  input_x.shape_status(luci::ShapeStatus::VALID);
  input_x.dim(0).unset();

  input_y.shape({2, 5, 3, 7});
  input_y.shape_status(luci::ShapeStatus::VALID);

  bmm.x(&input_x);
  bmm.y(&input_y);

  loco::TensorShape shape;
  luci::sinf::Rule shape_inf_rule;

  ASSERT_TRUE(shape_inf_rule.infer(&bmm, shape));

  // Expected shape is [2, 5, 4, 7]
  // NOTE(review): this assumes broadcasting an unknown batch dim against 5
  // yields a known 5 - confirm against sinf::broadcast_shape
  ASSERT_EQ(4, shape.rank());
  ASSERT_TRUE(shape.dim(0).known());
  ASSERT_TRUE(shape.dim(1).known());
  ASSERT_TRUE(shape.dim(2).known());
  ASSERT_TRUE(shape.dim(3).known());
  ASSERT_EQ(2, shape.dim(0).value());
  ASSERT_EQ(5, shape.dim(1).value());
  ASSERT_EQ(4, shape.dim(2).value());
  ASSERT_EQ(7, shape.dim(3).value());
}

0 comments on commit f6306dc

Please sign in to comment.