Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft][luci/service] Migrate helperPads to ShapeInferenceHelper #13848

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions compiler/luci/service/src/CircleShapeInferenceHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

#include "CircleShapeInferenceHelper.h"

#include "Check.h"

#include <limits>

#include <oops/InternalExn.h>

using namespace luci::sinf;
Expand Down Expand Up @@ -157,5 +161,57 @@ loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::Tensor
return output_shape;
}

loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleConst *paddings)
{
const loco::DataType S32 = loco::DataType::S32;
const loco::DataType S64 = loco::DataType::S64;

// TODO support other data type
LUCI_ASSERT(paddings->dtype() == S32 || paddings->dtype() == S64, "Support int 32/64 for now");
LUCI_ASSERT(paddings->rank() == 2, "paddings should be rank 2");

int32_t n = paddings->dim(0).value();
int32_t v = paddings->dim(1).value();

LUCI_ASSERT(v == 2, "paddings should be [n, 2]");
LUCI_ASSERT(n == int32_t(input_shape.rank()),
"paddings [n, 2] should have same value of input rank");

loco::TensorShape output_shape;

output_shape.rank(input_shape.rank());
for (int32_t ni = 0; ni < n; ++ni)
{
if (not input_shape.dim(ni).known())
{
output_shape.dim(ni).unset();
continue;
}
int32_t idx = ni * 2;
int value = input_shape.dim(ni).value();
if (paddings->dtype() == S32)
{
value += paddings->at<S32>(idx + 0); // left
value += paddings->at<S32>(idx + 1); // right
}
else
{
auto pl = paddings->at<S64>(idx + 0);
auto pr = paddings->at<S64>(idx + 1);
auto max = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
auto low = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest());
LUCI_ASSERT(pl <= max, "paddings is over 32 bit limit");
LUCI_ASSERT(pl >= low, "paddings is over 32 bit limit");
LUCI_ASSERT(pr <= max, "paddings is over 32 bit limit");
LUCI_ASSERT(pr >= low, "paddings is over 32 bit limit");
value += static_cast<int32_t>(pl); // left
value += static_cast<int32_t>(pr); // right
}
output_shape.dim(ni) = value;
}

return output_shape;
}

} // namespace sinf
} // namespace luci
4 changes: 4 additions & 0 deletions compiler/luci/service/src/CircleShapeInferenceHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ loco::TensorShape circle_shape(const luci::CircleNode *node);
// Throw an exception if x and y are not broadcastable.
loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y);

// Return shape of pad ops using paddings.
loco::TensorShape pad_shape(const loco::TensorShape &input_shape,
const luci::CircleConst *paddings);

/**
* @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
*
Expand Down
88 changes: 0 additions & 88 deletions compiler/luci/service/src/HelperPads.h

This file was deleted.

4 changes: 2 additions & 2 deletions compiler/luci/service/src/Nodes/CirclePad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

#include "CircleCloneNode.h"
#include "CircleShapeInferenceHelper.h"
#include "HelperPads.h"

namespace luci
{
Expand All @@ -35,7 +34,8 @@ loco::TensorShape Algorithm::visit(const luci::CirclePad *node)
{
// TODO support non-const case
auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
return use_paddings(node, paddings);
auto input_shape = circle_shape(loco::must_cast<const luci::CircleNode *>(node->input()));
icodo98 marked this conversation as resolved.
Show resolved Hide resolved
return pad_shape(input_shape, paddings);
}

} // namespace sinf
Expand Down