diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h index 1ef947a5111..1d698213915 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h @@ -136,7 +136,7 @@ class Algorithm final : public luci::CircleNodeVisitor // loco::TensorShape visit(const luci::CircleShape *node) final; // loco::TensorShape visit(const luci::CircleSin *node) final; // loco::TensorShape visit(const luci::CircleSlice *node) final; - // loco::TensorShape visit(const luci::CircleSoftmax *node) final; + loco::TensorShape visit(const luci::CircleSoftmax *node) final; // loco::TensorShape visit(const luci::CircleSpaceToBatchND *node) final; // loco::TensorShape visit(const luci::CircleSpaceToDepth *node) final; // loco::TensorShape visit(const luci::CircleSparseToDense *node) final; diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index 0d698e0130b..3aea8d0501e 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -2358,8 +2358,6 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor::visit(const luci::CircleSoftmax *node) return cloned; } +loco::TensorShape sinf::Algorithm::visit(const luci::CircleSoftmax *node) +{ + const auto logits = loco::must_cast(node->logits()); + + return sinf::circle_shape(logits); +} + } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSoftmax.test.cpp b/compiler/luci/service/src/Nodes/CircleSoftmax.test.cpp index c80b44d69b5..8ccaeb8784b 100644 --- a/compiler/luci/service/src/Nodes/CircleSoftmax.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleSoftmax.test.cpp @@ -16,6 +16,8 @@ #include "luci/Service/CircleNodeClone.h" +#include + #include TEST(CloneNodeTest, clone_Softmax) @@ -33,3 +35,79 @@ TEST(CloneNodeTest, clone_Softmax) ASSERT_NE(nullptr, cloned_sm); ASSERT_EQ(node_sm->beta(), cloned_sm->beta()); } + +TEST(ShapeRuleTest, softmax_static_shape) +{ + luci::CircleInput input; + luci::CircleSoftmax softmax; + + input.shape({1, 4, 3, 8}); + input.shape_status(luci::ShapeStatus::VALID); + + softmax.logits(&input); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&softmax, shape)); + ASSERT_EQ(4, shape.rank()); + ASSERT_TRUE(shape.dim(0).known()); + ASSERT_TRUE(shape.dim(1).known()); + ASSERT_TRUE(shape.dim(2).known()); + ASSERT_TRUE(shape.dim(3).known()); + ASSERT_EQ(1, shape.dim(0).value()); + ASSERT_EQ(4, shape.dim(1).value()); + ASSERT_EQ(3, shape.dim(2).value()); + ASSERT_EQ(8, shape.dim(3).value()); +} + +TEST(ShapeRuleTest, softmax_dynamic_shape) +{ + luci::CircleInput input; + luci::CircleSoftmax softmax; + + input.shape({1, 4, 3, 8}); + input.shape_status(luci::ShapeStatus::VALID); + input.dim(1).unset(); + + softmax.logits(&input); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&softmax, shape)); + ASSERT_EQ(4, shape.rank()); + ASSERT_TRUE(shape.dim(0).known()); + ASSERT_FALSE(shape.dim(1).known()); + ASSERT_TRUE(shape.dim(2).known()); + ASSERT_TRUE(shape.dim(3).known()); + ASSERT_EQ(1, shape.dim(0).value()); + ASSERT_EQ(0, shape.dim(1).value()); + ASSERT_EQ(3, shape.dim(2).value()); + ASSERT_EQ(8, shape.dim(3).value()); +} + +TEST(ShapeRuleTest, softmax_wrong_input_NEG) +{ + luci::CircleSoftmax softmax; + + softmax.logits(nullptr); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_ANY_THROW(shape_inf_rule.infer(&softmax, shape)); +} + +TEST(ShapeRuleTest, softmax_wrong_input_2_NEG) +{ + luci::CircleInput *input = nullptr; + luci::CircleSoftmax softmax; + + softmax.logits(input); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_ANY_THROW(shape_inf_rule.infer(&softmax, shape)); +}