diff --git a/src/conv/invokers/impl_gemm_dynamic.cpp b/src/conv/invokers/impl_gemm_dynamic.cpp index 1f68592f9b..c2a2ef06ca 100644 --- a/src/conv/invokers/impl_gemm_dynamic.cpp +++ b/src/conv/invokers/impl_gemm_dynamic.cpp @@ -536,7 +536,7 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( if(problem.GetOut().GetType() == miopenHalf) return use_fp32_global_split_on_fp16; if(problem.GetOut().GetType() == miopenBFloat16) - return need_set_zero; + return config.gemm_k_global_split > 0; return false; }(); const auto is_nchw = problem.IsLayoutDefault(); @@ -849,7 +849,10 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory( if(problem.GetOut().GetType() == miopenHalf) return use_fp32_global_split_on_fp16; if(problem.GetOut().GetType() == miopenBFloat16) - return need_set_zero; + { + return (y < stride_h || x < stride_w || dilation_h != 1 || dilation_w != 1 || + config.gemm_k_global_split > 0); + } return false; }(); const auto is_nchw = problem.IsLayoutDefault();