diff --git a/runtime/onert/core/include/ir/Operations.Include.h b/runtime/onert/core/include/ir/Operations.Include.h index 6352b8ed90b..259c174498d 100644 --- a/runtime/onert/core/include/ir/Operations.Include.h +++ b/runtime/onert/core/include/ir/Operations.Include.h @@ -67,6 +67,7 @@ #include "ir/operation/ResizeBilinear.h" #include "ir/operation/ResizeNearestNeighbor.h" #include "ir/operation/Reverse.h" +#include "ir/operation/RmsNorm.h" #include "ir/operation/RNN.h" #include "ir/operation/Select.h" #include "ir/operation/Shape.h" diff --git a/runtime/onert/core/include/ir/Operations.lst b/runtime/onert/core/include/ir/Operations.lst index 1f91aecb23f..cb19a2ad9a7 100644 --- a/runtime/onert/core/include/ir/Operations.lst +++ b/runtime/onert/core/include/ir/Operations.lst @@ -69,6 +69,7 @@ OP(Reshape) OP(ResizeBilinear) OP(ResizeNearestNeighbor) OP(Reverse) +OP(RmsNorm) OP(RNN) OP(Select) OP(Shape) diff --git a/runtime/onert/core/include/ir/operation/RmsNorm.h b/runtime/onert/core/include/ir/operation/RmsNorm.h new file mode 100644 index 00000000000..416f2e582e4 --- /dev/null +++ b/runtime/onert/core/include/ir/operation/RmsNorm.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_OPERATION_RMS_NORM_H__ +#define __ONERT_IR_OPERATION_RMS_NORM_H__ + +#include "ir/Operation.h" +#include "ir/InternalType.h" + +namespace onert +{ +namespace ir +{ +namespace operation +{ + +class RmsNorm : public Operation +{ +public: + enum Input + { + INPUT = 0, + GAMMA + }; + + struct Param + { + float epsilon; + }; + +public: + RmsNorm(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, + const Param ¶m); + +public: + void accept(OperationVisitor &v) const override; + OpCode opcode() const final { return OpCode::RmsNorm; } + +public: + const Param ¶m() const { return _param; } + +private: + Param _param; +}; + +} // namespace operation +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_OPERATION_RMS_NORM_H__ diff --git a/runtime/onert/core/src/ir/OperationDumper.cc b/runtime/onert/core/src/ir/OperationDumper.cc index e0f28795b9d..933bc757cbd 100644 --- a/runtime/onert/core/src/ir/OperationDumper.cc +++ b/runtime/onert/core/src/ir/OperationDumper.cc @@ -318,6 +318,13 @@ void OperationDumper::visit(const Reverse &node) dumpUnaryInputOp(node, axis); } +void OperationDumper::visit(const RmsNorm &node) +{ + std::string inputs = + "Gamma(" + std::to_string(node.getInputs().at(RmsNorm::Input::GAMMA).value()) + ")"; + dumpUnaryInputOp(node, inputs); +} + void OperationDumper::visit(const RNN &node) { VERBOSE(LIR) << "* RNN" << std::endl; diff --git a/runtime/onert/core/src/ir/OperationDumper.h b/runtime/onert/core/src/ir/OperationDumper.h index 99bf869d586..bcfea6d2590 100644 --- a/runtime/onert/core/src/ir/OperationDumper.h +++ b/runtime/onert/core/src/ir/OperationDumper.h @@ -70,6 +70,7 @@ class OperationDumper : public OperationVisitor void visit(const operation::ResizeBilinear &) override; void visit(const operation::ResizeNearestNeighbor &) override; void visit(const operation::Reverse &) override; + void visit(const operation::RmsNorm &) override; void visit(const operation::RNN &) override; void visit(const operation::Select &node) override; void visit(const operation::Shape &node) override; diff --git a/runtime/onert/core/src/ir/operation/RmsNorm.cc b/runtime/onert/core/src/ir/operation/RmsNorm.cc new file mode 100644 index 00000000000..2f6f4927772 --- /dev/null +++ b/runtime/onert/core/src/ir/operation/RmsNorm.cc @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/operation/RmsNorm.h" +#include "ir/OperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace operation +{ + +void RmsNorm::accept(OperationVisitor &v) const { v.visit(*this); } + +RmsNorm::RmsNorm(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, + const Param ¶m) + : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param} +{ +} + +} // namespace operation +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc b/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc index 4a6267b0a34..4ae72668dd9 100644 --- a/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc +++ b/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc @@ -412,6 +412,14 @@ operation::Reverse generateReverse() return operation::Reverse{OperandIndexSequence{1, 2}, OperandIndexSequence{0}}; } +operation::RmsNorm generateRmsNorm() +{ + operation::RmsNorm::Param param; + param.epsilon = 0.f; + + return operation::RmsNorm{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param}; +} + operation::RNN generateRNN() { operation::RNN::Param param; @@ -750,6 +758,9 @@ TEST(UntrainableOperation, testAllOps) const auto reverse = generateReverse(); verifyOp(reverse); + const auto rms_norm = generateRmsNorm(); + verifyOp(rms_norm); + const auto rnn = generateRNN(); verifyOp(rnn); @@ -1123,6 +1134,12 @@ TEST(UntrainableOperation, neg_TrainableOperationVisitor) EXPECT_ANY_THROW(visitor.invoke(*untrainable)); } + { + const auto rms_norm = generateRmsNorm(); + auto untrainable = generateUntrainableOperation(rms_norm); + EXPECT_ANY_THROW(visitor.invoke(*untrainable)); + } + { const auto rnn = generateRNN(); auto untrainable = generateUntrainableOperation(rnn);