From a72f6b4e25ff16b3efc091afeb048fcb1ac8307e Mon Sep 17 00:00:00 2001 From: Annop Wongwathanarat Date: Tue, 24 Sep 2024 09:04:52 +0000 Subject: [PATCH] Remove matmul primitive caching in ideep This removes aarch64 specific code branches for caching oneDNN matmul primitive in ideep. Instead, we rely on primitive caching in oneDNN. Since oneDNN v3.6 matmul primitive with ACL backend leverages the experimental stateless API in ACL and thus can be cached more efficiently. --- include/ideep/operators/matmul.hpp | 54 +++--------------------------- 1 file changed, 4 insertions(+), 50 deletions(-) diff --git a/include/ideep/operators/matmul.hpp b/include/ideep/operators/matmul.hpp index a1bb4e7e..698a6623 100644 --- a/include/ideep/operators/matmul.hpp +++ b/include/ideep/operators/matmul.hpp @@ -48,11 +48,7 @@ struct matmul_forward_params { }; struct matmul_forward : public dnnl::matmul, -#ifdef __aarch64__ - utils::computation_cache > { -#else utils::computation_cache { -#endif using super = dnnl::matmul; // 2-in-1 compute for fp32 op with bias. Bias is disabled if it is empty. @@ -883,20 +879,6 @@ struct matmul_forward : public dnnl::matmul, with_bias, omp_get_max_threads()); -#ifdef __aarch64__ - auto pd_pair = fetch_or_create(key, [&]() { - if (with_bias) { - param.pd = primitive_desc( - aengine, src_desc, weights_desc, bias_desc, dst_desc, op_attr); - } else { - param.pd = primitive_desc( - aengine, src_desc, weights_desc, dst_desc, op_attr); - } - return std::make_pair(param.pd, super(param.pd)); - }); - param.pd = std::move(pd_pair.first); - param.primitive = std::move(pd_pair.second); -#else param.pd = fetch_or_create(key, [&]() { if (with_bias) { return primitive_desc( @@ -907,7 +889,7 @@ struct matmul_forward : public dnnl::matmul, } }); param.primitive = std::move(super(param.pd)); -#endif + if (param.op_attr.has_scales()) { if (!param.all_scales) { param.all_scales.reset(new std::unordered_map); @@ -1065,20 +1047,7 @@ struct matmul_forward : public dnnl::matmul, op_attr, with_bias, omp_get_max_threads()); -#ifdef __aarch64__ - auto pd_pair = fetch_or_create(key, [&]() { - if (with_bias) { - param.pd = primitive_desc( - aengine, src_desc, weights_desc, bias_desc, dst_desc, op_attr); - } else { - param.pd = primitive_desc( - aengine, src_desc, weights_desc, dst_desc, op_attr); - } - return std::make_pair(param.pd, super(param.pd)); - }); - param.pd = std::move(pd_pair.first); - param.primitive = std::move(pd_pair.second); -#else + param.pd = fetch_or_create(key, [&]() { if (with_bias) { return primitive_desc( @@ -1089,7 +1058,7 @@ struct matmul_forward : public dnnl::matmul, } }); param.primitive = std::move(super(param.pd)); -#endif + if (param.op_attr.has_scales()) { if (!param.all_scales) { param.all_scales.reset(new std::unordered_map); @@ -1222,20 +1191,6 @@ struct matmul_forward : public dnnl::matmul, omp_get_max_threads()); // Create pd and primitive -#ifdef __aarch64__ - auto pd_pair = fetch_or_create(key, [&]() { - if (with_bias) { - param.pd = primitive_desc( - aengine, src_desc, weights.get_desc(), bias_desc, dst_desc, op_attr); - } else { - param.pd = primitive_desc( - aengine, src_desc, weights.get_desc(), dst_desc, op_attr); - } - return std::make_pair(param.pd, super(param.pd)); - }); - param.pd = std::move(pd_pair.first); - param.primitive = std::move(pd_pair.second); -#else param.pd = fetch_or_create(key, [&]() { if (with_bias) { return primitive_desc( @@ -1246,7 +1201,6 @@ struct matmul_forward : public dnnl::matmul, } }); param.primitive = super(param.pd); -#endif // Create src reorder primitive with runtime scales/zero point auto src_reorder_pd = dnnl::reorder::primitive_desc(aengine, src.get_desc(), aengine, src_desc, src_attr); @@ -1529,4 +1483,4 @@ struct matmul_forward : public dnnl::matmul, } // namespace ideep -#endif \ No newline at end of file +#endif