Skip to content

Commit

Permalink
[luci] infer dynamic shape for pad
Browse files Browse the repository at this point in the history
This infers dynmic shape for pad.
If input shape is unknown, output shape is also unknown.

ONE-DCO-1.0-Signed-off-by: JuYoung Lee <[email protected]>
  • Loading branch information
icodo98 committed Aug 22, 2024
1 parent ac4d588 commit 36cc694
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 14 deletions.
36 changes: 22 additions & 14 deletions compiler/luci/service/src/CircleShapeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,25 +237,33 @@ loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *pa
{
int32_t idx = ni * 2;
int value = input_shape.dim(ni).value();
if (paddings->dtype() == S32)
if (!input_shape.dim(ni).known())
{
value += paddings->at<S32>(idx + 0); // left
value += paddings->at<S32>(idx + 1); // right
output_shape.dim(ni).unset();
}
else
{
auto pl = paddings->at<S64>(idx + 0);
auto pr = paddings->at<S64>(idx + 1);
auto max = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
auto low = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest());
LUCI_ASSERT(pl <= max, "paddings is over 32 bit limit");
LUCI_ASSERT(pl >= low, "paddings is over 32 bit limit");
LUCI_ASSERT(pr <= max, "paddings is over 32 bit limit");
LUCI_ASSERT(pr >= low, "paddings is over 32 bit limit");
value += static_cast<int32_t>(pl); // left
value += static_cast<int32_t>(pr); // right

if (paddings->dtype() == S32)
{
value += paddings->at<S32>(idx + 0); // left
value += paddings->at<S32>(idx + 1); // right
}
else
{
auto pl = paddings->at<S64>(idx + 0);
auto pr = paddings->at<S64>(idx + 1);
auto max = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
auto low = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest());
LUCI_ASSERT(pl <= max, "paddings is over 32 bit limit");
LUCI_ASSERT(pl >= low, "paddings is over 32 bit limit");
LUCI_ASSERT(pr <= max, "paddings is over 32 bit limit");
LUCI_ASSERT(pr >= low, "paddings is over 32 bit limit");
value += static_cast<int32_t>(pl); // left
value += static_cast<int32_t>(pr); // right
}
output_shape.dim(ni) = value;
}
output_shape.dim(ni) = value;
}

return loco::NodeShape{output_shape};
Expand Down
41 changes: 41 additions & 0 deletions compiler/luci/service/src/Nodes/CirclePad.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include "luci/Service/CircleNodeClone.h"

#include <luci/Service/CircleShapeInference.h>

#include <gtest/gtest.h>

TEST(CloneNodeTest, clone_Pad)
Expand All @@ -31,3 +33,42 @@ TEST(CloneNodeTest, clone_Pad)
auto cloned_pad = dynamic_cast<luci::CirclePad *>(cloned);
ASSERT_NE(nullptr, cloned_pad);
}

TEST(ShapeRuleTest, pad_dynamic_shape)
{
luci::CirclePad pad;
luci::CircleInput input;
// Use circle input as paddings
luci::CircleConst padddings;

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

input.shape({1, 2, 3, 4});
input.shape_status(luci::ShapeStatus::VALID);
input.dim(2).unset();

padddings.dtype(loco::DataType::S64);
padddings.shape({4, 2});
padddings.shape_status(luci::ShapeStatus::VALID);

const loco::DataType S64 = loco::DataType::S64;
uint32_t t = 64 * 8;
padddings.size<S64>(t);

pad.input(&input);
pad.paddings(&padddings);

ASSERT_TRUE(shape_inf_rule.infer(&pad, shape));
ASSERT_EQ(shape.rank(), 4);
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_FALSE(shape.dim(2).known());
ASSERT_TRUE(shape.dim(3).known());

ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(2, shape.dim(1).value());
ASSERT_EQ(0, shape.dim(2).value());
ASSERT_EQ(4, shape.dim(3).value());
pad.drop();
}

0 comments on commit 36cc694

Please sign in to comment.