diff --git a/include/ideep/operators/matmul.hpp b/include/ideep/operators/matmul.hpp index a1bb4e7e..698a6623 100644 --- a/include/ideep/operators/matmul.hpp +++ b/include/ideep/operators/matmul.hpp @@ -48,11 +48,7 @@ struct matmul_forward_params { }; struct matmul_forward : public dnnl::matmul, -#ifdef __aarch64__ - utils::computation_cache > { -#else utils::computation_cache { -#endif using super = dnnl::matmul; // 2-in-1 compute for fp32 op with bias. Bias is disabled if it is empty. @@ -883,20 +879,6 @@ struct matmul_forward : public dnnl::matmul, with_bias, omp_get_max_threads()); -#ifdef __aarch64__ - auto pd_pair = fetch_or_create(key, [&]() { - if (with_bias) { - param.pd = primitive_desc( - aengine, src_desc, weights_desc, bias_desc, dst_desc, op_attr); - } else { - param.pd = primitive_desc( - aengine, src_desc, weights_desc, dst_desc, op_attr); - } - return std::make_pair(param.pd, super(param.pd)); - }); - param.pd = std::move(pd_pair.first); - param.primitive = std::move(pd_pair.second); -#else param.pd = fetch_or_create(key, [&]() { if (with_bias) { return primitive_desc( @@ -907,7 +889,7 @@ struct matmul_forward : public dnnl::matmul, } }); param.primitive = std::move(super(param.pd)); -#endif + if (param.op_attr.has_scales()) { if (!param.all_scales) { param.all_scales.reset(new std::unordered_map); @@ -1065,20 +1047,7 @@ struct matmul_forward : public dnnl::matmul, op_attr, with_bias, omp_get_max_threads()); -#ifdef __aarch64__ - auto pd_pair = fetch_or_create(key, [&]() { - if (with_bias) { - param.pd = primitive_desc( - aengine, src_desc, weights_desc, bias_desc, dst_desc, op_attr); - } else { - param.pd = primitive_desc( - aengine, src_desc, weights_desc, dst_desc, op_attr); - } - return std::make_pair(param.pd, super(param.pd)); - }); - param.pd = std::move(pd_pair.first); - param.primitive = std::move(pd_pair.second); -#else + param.pd = fetch_or_create(key, [&]() { if (with_bias) { return primitive_desc( @@ -1089,7 +1058,7 @@ struct matmul_forward : public dnnl::matmul, } }); param.primitive = std::move(super(param.pd)); -#endif + if (param.op_attr.has_scales()) { if (!param.all_scales) { param.all_scales.reset(new std::unordered_map); @@ -1222,20 +1191,6 @@ struct matmul_forward : public dnnl::matmul, omp_get_max_threads()); // Create pd and primitive -#ifdef __aarch64__ - auto pd_pair = fetch_or_create(key, [&]() { - if (with_bias) { - param.pd = primitive_desc( - aengine, src_desc, weights.get_desc(), bias_desc, dst_desc, op_attr); - } else { - param.pd = primitive_desc( - aengine, src_desc, weights.get_desc(), dst_desc, op_attr); - } - return std::make_pair(param.pd, super(param.pd)); - }); - param.pd = std::move(pd_pair.first); - param.primitive = std::move(pd_pair.second); -#else param.pd = fetch_or_create(key, [&]() { if (with_bias) { return primitive_desc( @@ -1246,7 +1201,6 @@ struct matmul_forward : public dnnl::matmul, } }); param.primitive = super(param.pd); -#endif // Create src reorder primitive with runtime scales/zero point auto src_reorder_pd = dnnl::reorder::primitive_desc(aengine, src.get_desc(), aengine, src_desc, src_attr); @@ -1529,4 +1483,4 @@ struct matmul_forward : public dnnl::matmul, } // namespace ideep -#endif \ No newline at end of file +#endif