From eb7236462b7e52b90ddcd81ef784d0d54774f239 Mon Sep 17 00:00:00 2001 From: youngsik kim Date: Fri, 27 Sep 2024 09:17:52 +0900 Subject: [PATCH] [luci/service] Support RoPE operation (#14091) This commit supports RoPE for luci service ONE-DCO-1.0-Signed-off-by: youngsik kim ys44.kim@samsung.com --- .../luci/Service/CircleShapeInference.h | 1 + .../luci/Service/CircleTypeInference.h | 1 + compiler/luci/service/src/CircleCloneNode.h | 1 + .../service/src/CircleShapeInferenceRule.cpp | 7 +++ .../service/src/CircleTypeInferenceRule.cpp | 5 ++ .../luci/service/src/Nodes/CircleRoPE.cpp | 32 +++++++++++++ .../service/src/Nodes/CircleRoPE.test.cpp | 46 +++++++++++++++++++ 7 files changed, 93 insertions(+) create mode 100644 compiler/luci/service/src/Nodes/CircleRoPE.cpp create mode 100644 compiler/luci/service/src/Nodes/CircleRoPE.test.cpp diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h index 2c5c7e8c91f..6c16cb6d210 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h @@ -165,6 +165,7 @@ class Algorithm final : public luci::CircleNodeVisitor // loco::TensorShape visit(const luci::CircleBCQFullyConnected *node) final; // loco::TensorShape visit(const luci::CircleBCQGather *node) final; // loco::TensorShape visit(const luci::CircleInstanceNorm *node) final; + // loco::TensorShape visit(const luci::CircleRoPE *node) final; // Virtual // loco::TensorShape visit(const luci::CircleCustomOut *node) final; diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInference.h b/compiler/luci/service/include/luci/Service/CircleTypeInference.h index e725722a986..8db4c8f88a5 100644 --- a/compiler/luci/service/include/luci/Service/CircleTypeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleTypeInference.h @@ -164,6 +164,7 @@ class Algorithm final : public luci::CircleNodeVisitor // loco::DataType visit(const luci::CircleBCQFullyConnected *node) final; // loco::DataType visit(const luci::CircleBCQGather *node) final; // loco::DataType visit(const luci::CircleInstanceNorm *node) final; + // loco::DataType visit(const luci::CircleRoPE *node) final; // Virtual // loco::DataType visit(const luci::CircleInput *node) final; diff --git a/compiler/luci/service/src/CircleCloneNode.h b/compiler/luci/service/src/CircleCloneNode.h index 64c9e4f486f..20b5f14ee04 100644 --- a/compiler/luci/service/src/CircleCloneNode.h +++ b/compiler/luci/service/src/CircleCloneNode.h @@ -260,6 +260,7 @@ class CloneNode final : public luci::CircleNodeVisitor luci::CircleNode *visit(const luci::CircleInstanceNorm *) final; luci::CircleNode *visit(const luci::CircleGRU *) final; luci::CircleNode *visit(const luci::CircleRmsNorm *) final; + luci::CircleNode *visit(const luci::CircleRoPE *) final; // NOTE CircleInput and CircleOutput are not handled here as these need // link with graph I/O diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index a094b681d0c..cd27a6149a1 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -2205,6 +2205,13 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitorinput()).as(); + + return loco::NodeShape{input_shape}; + } + // Virtual loco::NodeShape visit(const luci::CircleInput *node) final { return infer_input(node); } diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.cpp index 6b656567071..5d9d22b7cfe 100644 --- a/compiler/luci/service/src/CircleTypeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleTypeInferenceRule.cpp @@ -584,6 +584,11 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitorinput()); } + loco::DataType visit(const luci::CircleRoPE *node) final + { + return luci::dtype_get(node->input()); + } + // Virtual loco::DataType visit(const luci::CircleInput *node) final { return node->dtype(); } diff --git a/compiler/luci/service/src/Nodes/CircleRoPE.cpp b/compiler/luci/service/src/Nodes/CircleRoPE.cpp new file mode 100644 index 00000000000..dc4a628e1ed --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRoPE.cpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleRoPE *node) +{ + if (node->mode() == luci::RoPEMode::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create(); + + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleRoPE.test.cpp b/compiler/luci/service/src/Nodes/CircleRoPE.test.cpp new file mode 100644 index 00000000000..a04b2e71975 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRoPE.test.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include + +TEST(CloneNodeTest, clone_RoPE) +{ + auto g = loco::make_graph(); + auto node_rp = g->nodes()->create(); + node_rp->mode(luci::RoPEMode::GPT_NEOX); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_rp, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_rp = dynamic_cast(cloned); + ASSERT_NE(nullptr, cloned_rp); + ASSERT_EQ(node_rp->mode(), cloned_rp->mode()); +} + +TEST(CloneNodeTest, clone_RoPE_NEG) +{ + auto g = loco::make_graph(); + auto node_rp = g->nodes()->create(); + node_rp->mode(luci::RoPEMode::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_rp, gc.get()); + ASSERT_EQ(nullptr, cloned); +}