Skip to content

Commit

Permalink
[Snippets][CPU] Added KVCacheMatcher check for LLM in MHATokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Feb 4, 2025
1 parent e50d722 commit 505f29d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 18 deletions.
8 changes: 4 additions & 4 deletions src/plugins/intel_cpu/src/nodes/pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ bool Pooling::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, st
return false;
}
#if defined(OV_CPU_WITH_ACL)
if (ov::as_type_ptr<const ov::op::v8::MaxPool>(op) ||
ov::as_type_ptr<const ov::op::v14::MaxPool>(op)) {
if (ov::as_type_ptr<const ov::op::util::MaxPoolBase>(op)->get_kernel() != ov::Shape(2,2)) {
errorMessage = "Pooling indices returning source tensor coordinates is only supported for pool size 2x2";
if (ov::as_type_ptr<const ov::op::v8::MaxPool>(op) || ov::as_type_ptr<const ov::op::v14::MaxPool>(op)) {
if (ov::as_type_ptr<const ov::op::util::MaxPoolBase>(op)->get_kernel() != ov::Shape(2, 2)) {
errorMessage =
"Pooling indices returning source tensor coordinates is only supported for pool size 2x2";
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@
# include "cpu/x64/cpu_isa_traits.hpp"
#endif
#include "openvino/core/validation_util.hpp"
#include "openvino/pass/pattern/op/or.hpp"

namespace ov {
namespace intel_cpu {
Expand Down Expand Up @@ -1039,9 +1040,34 @@ void Transformations::MainSnippets(void) {
// Currently, Snippets don't provide efficient execution for single token inference in LLM case.
// To avoid performance degradations, we disable MHA tokenization into Subgraphs in LLMs.
// We consider the presence of `ScaledDotProductAttentionWithKVCache` and `PagedAttentionExtension` ops
// in the model as a sign that this model is LLM.
const auto is_LLM = ov::op::util::has_op_with_type<intel_cpu::ScaledDotProductAttentionWithKVCache>(model) ||
ov::op::util::has_op_with_type<ov::op::PagedAttentionExtension>(model);
// and `KVCache` subgraph in the model as a sign that this model is LLM.
const auto is_LLM = [this]() -> bool {
using namespace ov::pass::pattern;

const auto past = wrap_type<ov::op::v6::ReadValue>();
const auto convert_past = wrap_type<ov::op::v0::Convert>({past});
const auto gather_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past, convert_past});
const auto beam_idx = wrap_type<ov::op::v0::Parameter>();
const auto gather_past =
wrap_type<ov::op::v8::Gather>({gather_input, beam_idx, wrap_type<ov::op::v0::Constant>()});
const auto gather_convert = wrap_type<ov::op::v0::Convert>({gather_past});
const auto concat_past_input =
std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past, convert_past, gather_past, gather_convert});
const auto concat = wrap_type<ov::op::v0::Concat>({concat_past_input, any_input()});
const auto convert_present = wrap_type<ov::op::v0::Convert>({concat});
const auto present_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{concat, convert_present});
const auto present = wrap_type<ov::op::v6::Assign>({present_input});
const auto kvcache_matcher = std::make_shared<ov::pass::pattern::Matcher>(present, "KVCacheMatcher");

for (const auto& op : model->get_ordered_ops()) {
if (kvcache_matcher->match(op))
return true;
if (ov::is_type<intel_cpu::ScaledDotProductAttentionWithKVCache>(op) ||
ov::is_type<ov::op::PagedAttentionExtension>(op))
return true;
}
return false;
}();

// CPU Plugin Subgraph supports f32, bf16, quantized and fp16(on avx_512_core_amx_fp16 target) BRGEMM
const auto is_infer_prc_supported_by_MHA =
Expand Down Expand Up @@ -1121,16 +1147,16 @@ void Transformations::MainSnippets(void) {
ov::is_type<ov::op::v0::Convert>(n) || ov::is_type<ov::op::v1::Divide>(n) ||
ov::is_type<ov::op::v0::Elu>(n) || ov::is_type<ov::op::v0::Exp>(n) ||
ov::is_type<ov::op::v1::Equal>(n) || ov::is_type<ov::op::v0::FakeQuantize>(n) ||
ov::is_type<ov::op::v0::Floor>(n) || ov::is_type<ov::op::v1::FloorMod>(n) ||
ov::is_type<ov::op::v0::Gelu>(n) || ov::is_type<ov::op::v7::Gelu>(n) ||
ov::is_type<ov::op::v1::Greater>(n) || ov::is_type<ov::op::v1::GreaterEqual>(n) ||
ov::is_type<ov::op::v4::HSwish>(n) || ov::is_type<ov::op::v1::LessEqual>(n) ||
ov::is_type<ov::op::v1::Maximum>(n) || ov::is_type<ov::op::v1::Minimum>(n) ||
ov::is_type<ov::op::v4::Mish>(n) || ov::is_type<ov::op::v1::Mod>(n) ||
ov::is_type<ov::op::v1::Multiply>(n) || ov::is_type<ov::op::v0::PRelu>(n) ||
ov::is_type<ov::op::v0::Relu>(n) || ov::is_type<ov::op::v5::Round>(n) ||
ov::is_type<ov::op::v0::Sigmoid>(n) || ov::is_type<ov::op::v0::Sqrt>(n) ||
ov::is_type<ov::op::v1::Subtract>(n) || ov::is_type<ov::op::v4::Swish>(n) ||
ov::is_type<ov::op::v0::Floor>(n) || ov::is_type<ov::op::v1::FloorMod>(n) ||
ov::is_type<ov::op::v0::Gelu>(n) || ov::is_type<ov::op::v7::Gelu>(n) ||
ov::is_type<ov::op::v1::Greater>(n) || ov::is_type<ov::op::v1::GreaterEqual>(n) ||
ov::is_type<ov::op::v4::HSwish>(n) || ov::is_type<ov::op::v1::LessEqual>(n) ||
ov::is_type<ov::op::v1::Maximum>(n) || ov::is_type<ov::op::v1::Minimum>(n) ||
ov::is_type<ov::op::v4::Mish>(n) || ov::is_type<ov::op::v1::Mod>(n) ||
ov::is_type<ov::op::v1::Multiply>(n) || ov::is_type<ov::op::v0::PRelu>(n) ||
ov::is_type<ov::op::v0::Relu>(n) || ov::is_type<ov::op::v5::Round>(n) ||
ov::is_type<ov::op::v0::Sigmoid>(n) || ov::is_type<ov::op::v0::Sqrt>(n) ||
ov::is_type<ov::op::v1::Subtract>(n) || ov::is_type<ov::op::v4::Swish>(n) ||
ov::is_type<ov::op::v0::Tanh>(n));
#else
// CPU Plugin supports Swish in Subgraph via conversion to SwishCPU which assumes second input to be constant,
Expand Down
3 changes: 2 additions & 1 deletion src/plugins/intel_cpu/src/utils/verbose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ void Verbose::printInfo() {
std::string dim_str = {};
if (DnnlExtensionUtils::ElementTypeToDataType(desc->getPrecision(), DnnlExtensionUtils::nothrow_tag{})) {
if (auto dnnl_desc = MemoryDescUtils::convertToDnnlMemoryDesc(desc)->getDnnlDesc()) {
fmt_str = dnnl::impl::md2fmt_str("", dnnl_desc.get(), dnnl::impl::format_kind_t::dnnl_format_kind_undef);
fmt_str =
dnnl::impl::md2fmt_str("", dnnl_desc.get(), dnnl::impl::format_kind_t::dnnl_format_kind_undef);
std::string dim_str = dnnl::impl::md2dim_str(dnnl_desc.get());
} else {
fmt_str = "empty";
Expand Down

0 comments on commit 505f29d

Please sign in to comment.