Skip to content

Commit

Permalink
[luci/pass] Canonicalize PadV2 paddings (#13542)
Browse files Browse the repository at this point in the history
This will enable Canonicalize PadV2 paddings dtype.

ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark authored Jul 29, 2024
1 parent f41dcf9 commit 19b0b2b
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 0 deletions.
50 changes: 50 additions & 0 deletions compiler/luci/pass/src/CanonicalizePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,51 @@ bool paddings_to_s32(luci::CirclePad *pad)
return true;
}

/**
* Convert S64 CircleConst paddings to S32
*/
bool paddings_to_s32(luci::CirclePadV2 *padv2)
{
// check conditions
auto paddings = dynamic_cast<luci::CircleConst *>(padv2->paddings());
CHECK_OR_FALSE(paddings);
CHECK_OR_FALSE(paddings->dtype() == loco::DataType::S64);

// TODO relocate to helpers/CreateCircleConst.h when necessary
auto num_elements = paddings->size<loco::DataType::S64>();
auto hval = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
auto lval = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest());
for (uint32_t i = 0; i < num_elements; i++)
{
auto v64 = paddings->at<loco::DataType::S64>(i);
CHECK_OR_FALSE(v64 <= hval);
CHECK_OR_FALSE(v64 >= lval);
}

auto paddings_s32 = padv2->graph()->nodes()->create<luci::CircleConst>();
paddings_s32->name(paddings->name() + "_S32");
paddings_s32->dtype(loco::DataType::S32);
paddings_s32->rank(paddings->rank());
for (uint32_t i = 0; i < paddings->rank(); i++)
paddings_s32->dim(i).set(paddings->dim(i).value());
paddings_s32->shape_status(luci::ShapeStatus::VALID);
luci::add_origin(paddings_s32, luci::get_origin(paddings));

paddings_s32->size<loco::DataType::S32>(num_elements);
for (uint32_t i = 0; i < num_elements; i++)
{
auto v64 = paddings->at<loco::DataType::S64>(i);
paddings_s32->at<loco::DataType::S32>(i) = static_cast<int32_t>(v64);
}

// replace paddings with S32 dtype
padv2->paddings(paddings_s32);

return true;
}

// TODO merge both paddings_to_s32 with template

} // namespace

namespace luci
Expand All @@ -91,6 +136,11 @@ bool CanonicalizePass::run(loco::Graph *g)
if (paddings_to_s32(pad))
changed = true;
}
else if (auto padv2 = dynamic_cast<luci::CirclePadV2 *>(node))
{
if (paddings_to_s32(padv2))
changed = true;
}

// TODO add more canonicalization
}
Expand Down
74 changes: 74 additions & 0 deletions compiler/luci/pass/src/CanonicalizePass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,13 @@ struct PadGraphlet
void init(loco::Graph *g)
{
_pad = g->nodes()->create<luci::CirclePad>();
_padv2 = g->nodes()->create<luci::CirclePadV2>();
_paddings_s32 = g->nodes()->create<luci::CircleConst>();
_paddings_s64 = g->nodes()->create<luci::CircleConst>();
// NOTE PadV2.constant_values is not set as test doesn't use this

_pad->name("pad");
_padv2->name("padv2");
_paddings_s32->name("paddings_s32");
_paddings_s64->name("paddings_s64");

Expand Down Expand Up @@ -75,6 +78,7 @@ struct PadGraphlet
}

luci::CirclePad *_pad = nullptr;
luci::CirclePadV2 *_padv2 = nullptr;
luci::CircleConst *_paddings_s32 = nullptr;
luci::CircleConst *_paddings_s64 = nullptr;
};
Expand All @@ -96,6 +100,23 @@ class CanonicalizePadTestGraph : public TestIOGraph, public PadGraphlet
}
};

class CanonicalizePadV2TestGraph : public TestIOGraph, public PadGraphlet
{
public:
CanonicalizePadV2TestGraph() = default;

void init(void)
{
TestIOGraph::init({1, 3, 3, 2}, {1, 5, 5, 2});
PadGraphlet::init(g());

_padv2->input(input());
_padv2->paddings(_paddings_s64);

output()->from(_padv2);
}
};

} // namespace

TEST(CanonicalizePassPadTest, paddings_64_to_32)
Expand Down Expand Up @@ -150,3 +171,56 @@ TEST(CanonicalizePassPadTest, paddings_32_over_NEG)
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S64);
}

TEST(CanonicalizePassPadV2Test, paddings_64_to_32)
{
CanonicalizePadV2TestGraph g;
luci::CanonicalizePass pass;

g.init();

luci::CircleConst *paddings = dynamic_cast<luci::CircleConst *>(g._padv2->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S64);

EXPECT_TRUE(pass.run(g.g()));

paddings = dynamic_cast<luci::CircleConst *>(g._padv2->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S32);
}

TEST(CanonicalizePassPadV2Test, paddings_32_NEG)
{
CanonicalizePadV2TestGraph g;
luci::CanonicalizePass pass;

g.init();
g._padv2->paddings(g._paddings_s32);

luci::CircleConst *paddings = dynamic_cast<luci::CircleConst *>(g._padv2->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S32);

EXPECT_FALSE(pass.run(g.g()));

paddings = dynamic_cast<luci::CircleConst *>(g._padv2->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S32);
}

TEST(CanonicalizePassPadV2Test, paddings_32_over_NEG)
{
CanonicalizePadV2TestGraph g;
luci::CanonicalizePass pass;

g.init();
g._paddings_s64->at<loco::DataType::S64>(2) =
static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 100;

EXPECT_FALSE(pass.run(g.g()));

luci::CircleConst *paddings = dynamic_cast<luci::CircleConst *>(g._padv2->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S64);
}

0 comments on commit 19b0b2b

Please sign in to comment.