diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index 877bd5a0ae7..4024b8d51bc 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -237,25 +237,33 @@ loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *pa { int32_t idx = ni * 2; int value = input_shape.dim(ni).value(); - if (paddings->dtype() == S32) + if (!input_shape.dim(ni).known()) { - value += paddings->at(idx + 0); // left - value += paddings->at(idx + 1); // right + output_shape.dim(ni).unset(); } else { - auto pl = paddings->at(idx + 0); - auto pr = paddings->at(idx + 1); - auto max = static_cast(std::numeric_limits::max()); - auto low = static_cast(std::numeric_limits::lowest()); - LUCI_ASSERT(pl <= max, "paddings is over 32 bit limit"); - LUCI_ASSERT(pl >= low, "paddings is over 32 bit limit"); - LUCI_ASSERT(pr <= max, "paddings is over 32 bit limit"); - LUCI_ASSERT(pr >= low, "paddings is over 32 bit limit"); - value += static_cast(pl); // left - value += static_cast(pr); // right + + if (paddings->dtype() == S32) + { + value += paddings->at(idx + 0); // left + value += paddings->at(idx + 1); // right + } + else + { + auto pl = paddings->at(idx + 0); + auto pr = paddings->at(idx + 1); + auto max = static_cast(std::numeric_limits::max()); + auto low = static_cast(std::numeric_limits::lowest()); + LUCI_ASSERT(pl <= max, "paddings is over 32 bit limit"); + LUCI_ASSERT(pl >= low, "paddings is over 32 bit limit"); + LUCI_ASSERT(pr <= max, "paddings is over 32 bit limit"); + LUCI_ASSERT(pr >= low, "paddings is over 32 bit limit"); + value += static_cast(pl); // left + value += static_cast(pr); // right + } + output_shape.dim(ni) = value; } - output_shape.dim(ni) = value; } return loco::NodeShape{output_shape}; diff --git a/compiler/luci/service/src/Nodes/CirclePad.test.cpp b/compiler/luci/service/src/Nodes/CirclePad.test.cpp index 1d5f8375e1f..8e18d2b20ee 100644 --- a/compiler/luci/service/src/Nodes/CirclePad.test.cpp +++ b/compiler/luci/service/src/Nodes/CirclePad.test.cpp @@ -16,6 +16,8 @@ #include "luci/Service/CircleNodeClone.h" +#include + #include TEST(CloneNodeTest, clone_Pad) @@ -31,3 +33,42 @@ TEST(CloneNodeTest, clone_Pad) auto cloned_pad = dynamic_cast(cloned); ASSERT_NE(nullptr, cloned_pad); } + +TEST(ShapeRuleTest, pad_dynamic_shape) +{ + luci::CirclePad pad; + luci::CircleInput input; + // Use circle input as paddings + luci::CircleConst padddings; + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + input.shape({1, 2, 3, 4}); + input.shape_status(luci::ShapeStatus::VALID); + input.dim(2).unset(); + + padddings.dtype(loco::DataType::S64); + padddings.shape({4, 2}); + padddings.shape_status(luci::ShapeStatus::VALID); + + const loco::DataType S64 = loco::DataType::S64; + uint32_t t = 64 * 8; + padddings.size(t); + + pad.input(&input); + pad.paddings(&padddings); + + ASSERT_TRUE(shape_inf_rule.infer(&pad, shape)); + ASSERT_EQ(shape.rank(), 4); + ASSERT_TRUE(shape.dim(0).known()); + ASSERT_TRUE(shape.dim(1).known()); + ASSERT_FALSE(shape.dim(2).known()); + ASSERT_TRUE(shape.dim(3).known()); + + ASSERT_EQ(1, shape.dim(0).value()); + ASSERT_EQ(2, shape.dim(1).value()); + ASSERT_EQ(0, shape.dim(2).value()); + ASSERT_EQ(4, shape.dim(3).value()); + pad.drop(); +}