diff --git a/src/solver/conv/conv_direct_naive_conv.cpp b/src/solver/conv/conv_direct_naive_conv.cpp index 505f5c9376..bb591e3159 100644 --- a/src/solver/conv/conv_direct_naive_conv.cpp +++ b/src/solver/conv/conv_direct_naive_conv.cpp @@ -361,11 +361,27 @@ GetConv2DFWDSolution(const ExecutionContext& ctx, const ::miopen::conv::ProblemD size_t grid_size = 1; if(problem.IsLayoutDefault()) { - grid_size = static_cast(n) * k; + size_t all_workload = static_cast(n) * k; + if(all_workload <= block_size) + { + grid_size = all_workload; + } + else + { + grid_size = (all_workload + block_size - 1) / block_size; + } } else if(problem.IsLayoutNHWC()) { - grid_size = static_cast(group) * n * ho; + size_t all_workload = static_cast(group) * n * ho; + if(all_workload <= block_size) + { + grid_size = all_workload; + } + else + { + grid_size = (all_workload + block_size - 1) / block_size; + } } else MIOPEN_THROW("Unsupported layout"); @@ -507,13 +523,30 @@ GetConv3DFWDSolution(const ExecutionContext& ctx, const ::miopen::conv::ProblemD size_t block_size = 256; size_t grid_size = 1; + if(problem.IsLayoutDefault()) { - grid_size = static_cast(n) * k; + size_t all_workload = static_cast(n) * k; + if(all_workload <= block_size) + { + grid_size = all_workload; + } + else + { + grid_size = (all_workload + block_size - 1) / block_size; + } } else if(problem.IsLayoutNHWC()) { - grid_size = static_cast(group) * n * do_; + size_t all_workload = static_cast(group) * n * do_; + if(all_workload <= block_size) + { + grid_size = all_workload; + } + else + { + grid_size = (all_workload + block_size - 1) / block_size; + } } else MIOPEN_THROW("Unsupported layout"); @@ -867,11 +900,27 @@ GetConv2DBWDSolution(const ExecutionContext& ctx, const ::miopen::conv::ProblemD size_t grid_size = 1; if(problem.IsLayoutDefault()) { - grid_size = static_cast(n) * c; + size_t all_workload = static_cast(n) * c; + if(all_workload <= block_size) + { + grid_size = all_workload; + } + else + { + grid_size = (all_workload + block_size - 1) / block_size; + } } else if(problem.IsLayoutNHWC()) { - grid_size = static_cast(group) * n * hi; + size_t all_workload = static_cast(group) * n * hi; + if(all_workload <= block_size) + { + grid_size = all_workload; + } + else + { + grid_size = (all_workload + block_size - 1) / block_size; + } } else { @@ -1017,11 +1066,27 @@ GetConv3DBWDSolution(const ExecutionContext& ctx, const ::miopen::conv::ProblemD size_t grid_size = 1; if(problem.IsLayoutDefault()) { - grid_size = static_cast(n) * c; + size_t all_workload = static_cast(n) * c; + if(all_workload <= block_size) + { + grid_size = all_workload; + } + else + { + grid_size = (all_workload + block_size - 1) / block_size; + } } else if(problem.IsLayoutNHWC()) { - grid_size = static_cast(group) * n * di; + size_t all_workload = static_cast(group) * n * di; + if(all_workload <= block_size) + { + grid_size = all_workload; + } + else + { + grid_size = (all_workload + block_size - 1) / block_size; + } } else {