Skip to content

Commit

Permalink
[luci/service] Support SOFTMAX dynamic shape inference (#13784)
Browse files Browse the repository at this point in the history
* [luci/service] Support SOFTMAX dynamic shape inference

This supports dynamic shape inference for SOFTMAX Op.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>

* Adjust headers
  • Loading branch information
jinevening authored Aug 27, 2024
1 parent 7510752 commit feef973
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape>
// loco::TensorShape visit(const luci::CircleShape *node) final;
// loco::TensorShape visit(const luci::CircleSin *node) final;
// loco::TensorShape visit(const luci::CircleSlice *node) final;
// loco::TensorShape visit(const luci::CircleSoftmax *node) final;
loco::TensorShape visit(const luci::CircleSoftmax *node) final;
// loco::TensorShape visit(const luci::CircleSpaceToBatchND *node) final;
// loco::TensorShape visit(const luci::CircleSpaceToDepth *node) final;
// loco::TensorShape visit(const luci::CircleSparseToDense *node) final;
Expand Down
2 changes: 0 additions & 2 deletions compiler/luci/service/src/CircleShapeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2358,8 +2358,6 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeS

loco::NodeShape visit(const luci::CircleSlice *node) final { return infer_slice(node); }

loco::NodeShape visit(const luci::CircleSoftmax *node) final { return use_logits(node); }

loco::NodeShape visit(const luci::CircleSpaceToBatchND *node) final
{
return infer_space_to_batch_nd(node);
Expand Down
10 changes: 10 additions & 0 deletions compiler/luci/service/src/Nodes/CircleSoftmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
* limitations under the License.
*/

#include "luci/Service/CircleShapeInference.h"

#include "CircleCloneNode.h"
#include "CircleShapeInferenceHelper.h"

namespace luci
{
Expand All @@ -27,4 +30,11 @@ luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSoftmax *node)
return cloned;
}

loco::TensorShape sinf::Algorithm::visit(const luci::CircleSoftmax *node)
{
const auto logits = loco::must_cast<luci::CircleNode *>(node->logits());

return sinf::circle_shape(logits);
}

} // namespace luci
78 changes: 78 additions & 0 deletions compiler/luci/service/src/Nodes/CircleSoftmax.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include "luci/Service/CircleNodeClone.h"

#include <luci/Service/CircleShapeInference.h>

#include <gtest/gtest.h>

TEST(CloneNodeTest, clone_Softmax)
Expand All @@ -33,3 +35,79 @@ TEST(CloneNodeTest, clone_Softmax)
ASSERT_NE(nullptr, cloned_sm);
ASSERT_EQ(node_sm->beta(), cloned_sm->beta());
}

TEST(ShapeRuleTest, softmax_static_shape)
{
luci::CircleInput input;
luci::CircleSoftmax softmax;

input.shape({1, 4, 3, 8});
input.shape_status(luci::ShapeStatus::VALID);

softmax.logits(&input);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&softmax, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_TRUE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_TRUE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(4, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(8, shape.dim(3).value());
}

TEST(ShapeRuleTest, softmax_dynamic_shape)
{
luci::CircleInput input;
luci::CircleSoftmax softmax;

input.shape({1, 4, 3, 8});
input.shape_status(luci::ShapeStatus::VALID);
input.dim(1).unset();

softmax.logits(&input);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&softmax, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_FALSE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_TRUE(shape.dim(3).known());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(0, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(8, shape.dim(3).value());
}

TEST(ShapeRuleTest, softmax_wrong_input_NEG)
{
luci::CircleSoftmax softmax;

softmax.logits(nullptr);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&softmax, shape));
}

TEST(ShapeRuleTest, softmax_wrong_input_2_NEG)
{
luci::CircleInput *input = nullptr;
luci::CircleSoftmax softmax;

softmax.logits(input);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&softmax, shape));
}

0 comments on commit feef973

Please sign in to comment.