Skip to content

Commit

Permalink
[luci/svc] Handle 0 in shape for Reshape Op shape inference (#14144)
Browse files Browse the repository at this point in the history
This will fix to handle 0 in shape for Reshape Op shape inference.

ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark authored Oct 2, 2024
1 parent cec4b7b commit 8672617
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions compiler/luci/service/src/Nodes/CircleReshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

#include <luci/Log.h>

#include <oops/InternalExn.h>

namespace
{

Expand Down Expand Up @@ -88,11 +90,29 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)

for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
{
shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
if (const_shape_node->at<S32>(axis) < 0)
{
shape_by_input.dim(axis).unset();
}
else if (const_shape_node->at<S32>(axis) == 0)
{
const auto node_tensor = loco::must_cast<luci::CircleNode *>(node->tensor());
// set dim value to input
if (node_tensor->shape_status() == luci::ShapeStatus::VALID && axis < node_tensor->rank())
shape_by_input.dim(axis) = node_tensor->dim(axis);
else
{
// stop to check if this case exist for debugging
INTERNAL_EXN("Check Reshape shape with 0");
}
}
else
{
shape_by_input.dim(axis).set(const_shape_node->at<S32>(axis));
}
// check valid or stop for debugging
LUCI_ASSERT(shape_by_input.dim(axis).value() > 0 || !shape_by_input.dim(axis).known(),
"Reshape infer shape is invalid.");
}
}
else
Expand Down Expand Up @@ -143,14 +163,26 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
{
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
{
const uint32_t dim_value = output_shape.dim(dim_index).value();
uint32_t dim_value = output_shape.dim(dim_index).value();
if (not output_shape.dim(dim_index).known())
{
LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
unknown_dim_index = dim_index;
}
else
{
if (!dim_value)
{
// refer https://github.com/Samsung/ONE/issues/14074#issuecomment-2370795003
// set dim value to follow input
if (dim_index < input_shape.rank())
dim_value = input_shape.dim(dim_index).value();
else
{
// stop to check if this case exist for debugging
INTERNAL_EXN("Check Reshape shape with 0");
}
}
output_element_count *= dim_value;
}
}
Expand Down

0 comments on commit 8672617

Please sign in to comment.