diff --git a/runtime/onert/core/src/compiler/ShapeValidator.cc b/runtime/onert/core/src/compiler/ShapeValidator.cc index 555c46d6a2a..e65efa8307e 100644 --- a/runtime/onert/core/src/compiler/ShapeValidator.cc +++ b/runtime/onert/core/src/compiler/ShapeValidator.cc @@ -1133,9 +1133,15 @@ void ShapeValidator::visit(const ir::operation::RmsNorm &node) const auto ifm_index{node.getInputs().at(ir::operation::RmsNorm::Input::INPUT)}; const auto gamma_index{node.getInputs().at(ir::operation::RmsNorm::Input::GAMMA)}; - OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4); - OP_REQUIRES(operands.at(ifm_index).shape() == operands.at(ofm_index).shape()); - OP_REQUIRES(operands.at(gamma_index).shape().rank() == 1); + const auto &ifm_shape = operands.at(ifm_index).shape(); + const auto &ofm_shape = operands.at(ofm_index).shape(); + const auto &gamma_shape = operands.at(gamma_index).shape(); + + OP_REQUIRES(ifm_shape.rank() == 3 || ifm_shape.rank() == 4); + OP_REQUIRES(ifm_shape == ofm_shape); + OP_REQUIRES(gamma_shape.rank() == 1); + OP_REQUIRES((gamma_shape.dim(0) == 1) || + (gamma_shape.dim(0) == ifm_shape.dim(ifm_shape.rank() - 1))); } } // namespace compiler