From 687e5cf8c5a511494c41bf061653e45c3c521bf0 Mon Sep 17 00:00:00 2001
From: "Sun, Jiayi"
Date: Fri, 13 Sep 2024 01:40:49 -0700
Subject: [PATCH] [inductor] Relax the conditions for loop split (#135335)

Summary:
This PR relaxes the conditions for the loop split optimization so that it also covers dynamic shape cases. The conditions that must now be met to apply the loop split optimization are:

1. No reduction and no modular index for any node.
2. The indexing_exprs of all nodes contain only one division (or more, but all the same), where the divisor is an integer, the dividend is one of the iter_vars, and this var, i.e. the dimension that needs to be split, is contiguous in all other indexing_exprs.

Example:

```
import torch
import torch.nn as nn


class GN(torch.nn.Module):
    def __init__(self, num_groups, num_channels):
        super(GN, self).__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        return self.gn(x)


input = torch.randn(2, 960, 96, 96).to(memory_format=torch.channels_last)
m = GN(32, 960).eval()
compiled_m = torch.compile(m, dynamic=True)
with torch.no_grad():
    compiled_m(input)
```

Before the loop split, the node's var_ranges are `{z0: s0, z1: s2, z2: s2, z3: 960}` and its indexing_exprs are `{'index0': 960*s2**2*z0 + 960*s2*z1 + 960*z2 + z3, 'index1': 32*z0 + (z3//30), 'index2': 30*s2**2, 'index3': z3, 'index4': 960*s2*z0*((s2**2//s2)) + 960*z1*((s2**2//s2)) + 960*z2 + z3}`. After the loop split, `z3` is split into `30*z3 + z4`, so the var_ranges become `{z0: s0, z1: s2, z2: s2, z3: 32, z4: 30}` and the indexing_exprs become `{'index0': 960*s2**2*z0 + 960*s2*z1 + 960*z2 + 30*z3 + z4, 'index1': 32*z0 + z3, 'index2': 30*s2**2, 'index3': 30*z3 + z4, 'index4': 960*s2*z0*((s2**2//s2)) + 960*z1*((s2**2//s2)) + 960*z2 + 30*z3 + z4}`. (A standalone sketch of this check, and of how the updated test exercises it, is appended after the diff at the end of this patch.)

Generated code:

- Before:
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'const int64_t', 'const int64_t'], '''
#include "/tmp/torchinductor_jiayisun/32/c32dcqa3qidvmunis4lucp3dhoicleq5qjfjfgvpiadbbzfp6ofy.h"
extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2, const int64_t ks0, const int64_t ks1)
{
    #pragma omp parallel num_threads(112)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(ks0); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(32L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<int64_t>(c10::div_floor_integer(static_cast<int64_t>((15L*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(8L))));
                        for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(static_cast<int64_t>(ks1*ks1)); x2+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(16L); x3+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (30L*x1) + (960L*x2) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(16));
                                tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                            }
                            for(int64_t x3=static_cast<int64_t>(16L); x3<static_cast<int64_t>(30L); x3+=static_cast<int64_t>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (30L*x1) + (960L*x2) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(14L));
                                masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, static_cast<int64_t>(14L), &wrecps0);
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<int64_t>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<int64_t>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(ks0); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(ks1); x1+=static_cast<int64_t>(1L))
                {
                    #pragma GCC ivdep
                    for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(ks1); x2+=static_cast<int64_t>(1L))
                    {
                        #pragma GCC ivdep
                        for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(960L); x3+=static_cast<int64_t>(1L))
                        {
                            auto tmp0 = in_ptr0[static_cast<int64_t>(x3 + (960L*x2) + (960L*ks1*x1) + (960L*x0*(static_cast<int64_t>(ks1*ks1))))];
                            auto tmp1 = out_ptr0[static_cast<int64_t>((32L*x0) + (c10::div_floor_integer(static_cast<int64_t>(x3), static_cast<int64_t>(30L))))];
                            auto tmp3 = out_ptr1[static_cast<int64_t>((32L*x0) + (c10::div_floor_integer(static_cast<int64_t>(x3), static_cast<int64_t>(30L))))];
                            auto tmp11 = in_ptr1[static_cast<int64_t>(x3)];
                            auto tmp13 = in_ptr2[static_cast<int64_t>(x3)];
                            auto tmp2 = decltype(tmp0)(tmp0 - tmp1);
                            auto tmp4 = 30L*(static_cast<int64_t>(ks1*ks1));
                            auto tmp5 = c10::convert<float>(tmp4);
                            auto tmp6 = tmp3 / tmp5;
                            auto tmp7 = static_cast<float>(1e-05);
                            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
                            auto tmp9 = 1 / std::sqrt(tmp8);
                            auto tmp10 = decltype(tmp2)(tmp2 * tmp9);
                            auto tmp12 = decltype(tmp10)(tmp10 * tmp11);
                            auto tmp14 = decltype(tmp12)(tmp12 + tmp13);
                            out_ptr2[static_cast<int64_t>(x3 + (960L*x2) + (960L*x1*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1)))) + (960L*ks1*x0*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1)))))] = tmp14;
                        }
                    }
                }
            }
        }
    }
}
''')


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
    args.clear()
    s0 = arg2_1
    s2 = arg3_1
    assert_size_stride(arg0_1, (960, ), (1, ))
    assert_size_stride(arg1_1, (960, ), (1, ))
    assert_size_stride(arg4_1, (s0, 960, s2, s2), (960*(s2*s2), 1, 960*s2, 960))
    buf0 = empty_strided_cpu((s0, 32, 1, 1), (32, 1, 32*s0, 32*s0), torch.float32)
    buf1 = empty_strided_cpu((s0, 32, 1, 1), (32, 1, 32*s0, 32*s0), torch.float32)
    buf3 = empty_strided_cpu((s0, 960, s2, s2), (960*s2*((s2*s2) // s2), 1, 960*((s2*s2) // s2), 960), torch.float32)
    cpp_fused_native_group_norm_0(arg4_1, arg0_1, arg1_1, buf0, buf1, buf3, s0, s2)
    del arg0_1
    del arg1_1
    del arg4_1
    return (buf3, )
```

- After:
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'const int64_t', 'const int64_t'], '''
#include "/tmp/torchinductor_jiayisun/32/c32dcqa3qidvmunis4lucp3dhoicleq5qjfjfgvpiadbbzfp6ofy.h"
extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2, const int64_t ks0, const int64_t ks1)
{
    #pragma omp parallel num_threads(112)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(ks0); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(32L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<int64_t>(c10::div_floor_integer(static_cast<int64_t>((15L*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(8L))));
                        for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(static_cast<int64_t>(ks1*ks1)); x2+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(16L); x3+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (30L*x1) + (960L*x2) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(16));
                                tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                            }
                            for(int64_t x3=static_cast<int64_t>(16L); x3<static_cast<int64_t>(30L); x3+=static_cast<int64_t>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (30L*x1) + (960L*x2) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(14L));
                                masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, static_cast<int64_t>(14L), &wrecps0);
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<int64_t>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<int64_t>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(ks0); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(ks1); x1+=static_cast<int64_t>(1L))
                {
                    #pragma GCC ivdep
                    for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(ks1); x2+=static_cast<int64_t>(1L))
                    {
                        #pragma GCC ivdep
                        for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(32L); x3+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x4=static_cast<int64_t>(0L); x4<static_cast<int64_t>(16L); x4+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x4 + (30L*x3) + (960L*x2) + (960L*ks1*x1) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(16));
                                auto tmp1 = out_ptr0[static_cast<int64_t>(x3 + (32L*x0))];
                                auto tmp4 = out_ptr1[static_cast<int64_t>(x3 + (32L*x0))];
                                auto tmp13 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x4 + (30L*x3)), static_cast<int64_t>(16));
                                auto tmp15 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x4 + (30L*x3)), static_cast<int64_t>(16));
                                auto tmp2 = at::vec::Vectorized<float>(tmp1);
                                auto tmp3 = tmp0 - tmp2;
                                auto tmp5 = 30L*(static_cast<int64_t>(ks1*ks1));
                                auto tmp6 = c10::convert<float>(tmp5);
                                auto tmp7 = tmp4 / tmp6;
                                auto tmp8 = static_cast<float>(1e-05);
                                auto tmp9 = decltype(tmp7)(tmp7 + tmp8);
                                auto tmp10 = 1 / std::sqrt(tmp9);
                                auto tmp11 = at::vec::Vectorized<float>(tmp10);
                                auto tmp12 = tmp3 * tmp11;
                                auto tmp14 = tmp12 * tmp13;
                                auto tmp16 = tmp14 + tmp15;
                                tmp16.store(out_ptr2 + static_cast<int64_t>(x4 + (30L*x3) + (960L*x2) + (960L*x1*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1)))) + (960L*ks1*x0*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1))))));
                            }
                            for(int64_t x4=static_cast<int64_t>(16L); x4<static_cast<int64_t>(30L); x4+=static_cast<int64_t>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x4 + (30L*x3) + (960L*x2) + (960L*ks1*x1) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(14L));
                                auto tmp1 = out_ptr0[static_cast<int64_t>(x3 + (32L*x0))];
                                auto tmp4 = out_ptr1[static_cast<int64_t>(x3 + (32L*x0))];
                                auto tmp13 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x4 + (30L*x3)), static_cast<int64_t>(14L));
                                auto tmp15 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x4 + (30L*x3)), static_cast<int64_t>(14L));
                                auto tmp2 = at::vec::Vectorized<float>(tmp1);
                                auto tmp3 = tmp0 - tmp2;
                                auto tmp5 = 30L*(static_cast<int64_t>(ks1*ks1));
                                auto tmp6 = c10::convert<float>(tmp5);
                                auto tmp7 = tmp4 / tmp6;
                                auto tmp8 = static_cast<float>(1e-05);
                                auto tmp9 = decltype(tmp7)(tmp7 + tmp8);
                                auto tmp10 = 1 / std::sqrt(tmp9);
                                auto tmp11 = at::vec::Vectorized<float>(tmp10);
                                auto tmp12 = tmp3 * tmp11;
                                auto tmp14 = tmp12 * tmp13;
                                auto tmp16 = tmp14 + tmp15;
                                tmp16.store(out_ptr2 + static_cast<int64_t>(x4 + (30L*x3) + (960L*x2) + (960L*x1*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1)))) + (960L*ks1*x0*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1))))), static_cast<int64_t>(14L));
                            }
                        }
                    }
                }
            }
        }
    }
}
''')


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
    args.clear()
    s0 = arg2_1
    s2 = arg3_1
    assert_size_stride(arg0_1, (960, ), (1, ))
    assert_size_stride(arg1_1, (960, ), (1, ))
    assert_size_stride(arg4_1, (s0, 960, s2, s2), (960*(s2*s2), 1, 960*s2, 960))
    buf0 = empty_strided_cpu((s0, 32, 1, 1), (32, 1, 32*s0, 32*s0), torch.float32)
    buf1 = empty_strided_cpu((s0, 32, 1, 1), (32, 1, 32*s0, 32*s0), torch.float32)
    buf3 = empty_strided_cpu((s0, 960, s2, s2), (960*s2*((s2*s2) // s2), 1, 960*((s2*s2) // s2), 960), torch.float32)
    cpp_fused_native_group_norm_0(arg4_1, arg0_1, arg1_1, buf0, buf1, buf3, s0, s2)
    del arg0_1
    del arg1_1
    del arg4_1
    return (buf3, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135335
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jansel
---
 test/inductor/test_cpu_repro.py | 26 ++++++++++-----------
 torch/_inductor/codegen/cpp.py  | 40 +++++++++++++++++----------------
 2 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index bc65658a1e2ec..e142b56d02fc6 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -3432,29 +3432,27 @@ def forward(self, x):
                 return self.group_norm(x)
 
         options = itertools.product(
-            vec_dtypes, [torch.contiguous_format, torch.channels_last]
+            vec_dtypes, [torch.contiguous_format, torch.channels_last], [True, False]
         )
-        for dtype, fmt in options:
+        for dtype, fmt, dynamic in options:
             torch._dynamo.reset()
             metrics.reset()
             mod = M().eval()
             x = torch.randn((2, 90, 6, 6), dtype=dtype).to(memory_format=fmt)
             with torch.no_grad():
-                self.common(mod, (x,))
+                expected = mod(x)
+                compiled_m = torch.compile(mod, dynamic=dynamic)
+                actual, code = run_and_get_cpp_code(compiled_m, x)
+                self.assertEqual(expected, actual)
             # 2 generated kernels (one for var_mean, the other for result)
             check_metrics_vec_kernel_count(2)
-            # check loop split optimization
-            if fmt == torch.channels_last:
-                torch._dynamo.reset()
-                metrics.reset()
-                with torch.no_grad():
-                    opt_mod = torch.compile(mod)
-                    _, code = run_and_get_cpp_code(opt_mod, x)
-                # check that there are no non_contiguous loads
-                FileCheck().check_count("__at_align__ std::array", 0, exactly=True).run(
-                    code
-                )
+            # check loop split optimization
+            if fmt == torch.channels_last:
+                # check that there are no non_contiguous loads
+                FileCheck().check_count(
+                    "__at_align__ std::array", 0, exactly=True
+                ).run(code)
 
     def test_int_div_vec(self):
         def fn(x, y, mode):
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 8bbe8023a312a..b483f588f99d0 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -4159,9 +4159,9 @@ def try_loop_split(self, nodes: List[SchedulerNode]):
         When one of the indexing_exprs contains a division, we eliminate the division by splitting
         the loop to avoid non-contiguous loads, subject to the following conditions:
         1. No reduction and no mudular index for all nodes.
-        2. Only one node's one indexing_exprs contains a division, according to this indexing_exprs,
-           we can get the dimension that needs to be split, and the split dimension is contiguous
-           in all other indexing_exprs.
+        2. The indexing_exprs of all nodes contain only one (or more, but all the same) division,
+           where the divisor is an integer, the dividend is one of the iter_vars, and this var,
+           i.e. the dimension that needs to be split, is contiguous in all other indexing_exprs.
 
         For example, if the node's var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs:
         {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2},
@@ -4182,8 +4182,8 @@ def try_loop_split(self, nodes: List[SchedulerNode]):
 
         split_var = None
         split_number = None
-        divide_index_name = None
         num_div = 0
+        div_expr_ = None
         match_div = False
         matched_node = None
 
@@ -4191,25 +4191,27 @@ def try_loop_split(self, nodes: List[SchedulerNode]):
             assert isinstance(node.node, ir.ComputedBuffer)
             _, original_body, _ = node.node.get_default_sizes_body()
             for name, expr in original_body.indexing_exprs.items():
-                num_div += expr.count(FloorDiv)
-                if num_div > 1:
-                    return nodes
-                if expr.count(FloorDiv) == 1:
-                    div_expr = expr.find(FloorDiv).pop()
-                    split_var = div_expr.args[0]
-                    split_number = div_expr.args[1]
-                    divide_index_name = name
+                for div_expr in expr.find(FloorDiv):
                     if (
-                        isinstance(split_number, sympy.core.numbers.Integer)
-                        and isinstance(split_var, sympy.core.symbol.Symbol)
-                        and split_var in original_body.iter_vars
-                        and divide_index_name is not None
+                        any(div_expr.has(var) for var in original_body.iter_vars)
+                        and div_expr != div_expr_
+                    ):
+                        div_expr_ = div_expr
+                        num_div += 1
+                    if num_div > 1:
+                        return nodes
+                    if (
+                        isinstance(div_expr.args[1], sympy.core.numbers.Integer)
+                        and div_expr.args[0] in original_body.iter_vars
+                        and name is not None
                         and all(
-                            stride_at_vec_range(expr, split_var) == 1
-                            for name, expr in original_body.indexing_exprs.items()
-                            if name != divide_index_name
+                            stride_at_vec_range(expr_, div_expr.args[0]) in (0, 1)
+                            for name_, expr_ in original_body.indexing_exprs.items()
+                            if name_ != name
                         )
                     ):
+                        split_var = div_expr.args[0]
+                        split_number = div_expr.args[1]
                         match_div = True
                         matched_node = node
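
---

To make condition 2 from the summary concrete, here is a minimal standalone sketch. It uses plain sympy and is not the Inductor implementation above: `stride_of`, the use of `sympy.floor` in place of Inductor's `FloorDiv`, and the trimmed `indexing_exprs` dict are illustrative only. `index4` is omitted because its `(s2**2//s2)` division involves only shape symbols, not iter_vars, and is therefore ignored by the check.

```
import sympy

z0, z1, z2, z3, s0, s2 = sympy.symbols("z0 z1 z2 z3 s0 s2", integer=True, positive=True)
iter_vars = [z0, z1, z2, z3]

# A trimmed version of the example's indexing_exprs (index4 omitted, see above).
indexing_exprs = {
    "index0": 960 * s2**2 * z0 + 960 * s2 * z1 + 960 * z2 + z3,
    "index1": 32 * z0 + sympy.floor(z3 / 30),  # the division that triggers the split
    "index2": 30 * s2**2,
    "index3": z3,
}

# Collect the distinct divisions that involve an iteration variable.
divs = {
    d
    for expr in indexing_exprs.values()
    for d in expr.atoms(sympy.floor)
    if any(d.has(v) for v in iter_vars)
}


def stride_of(expr, var):
    # Stride of `var` in `expr`: how much the expression changes when `var` grows by one.
    return sympy.simplify(expr.subs(var, var + 1) - expr)


if len(divs) == 1:
    div = next(iter(divs))
    dividend, divisor = div.args[0].as_numer_denom()  # floor(z3/30) -> (z3, 30)
    splittable = (
        divisor.is_Integer
        and dividend in iter_vars
        # the split dimension must be contiguous (stride 0 or 1) in all other exprs
        and all(
            stride_of(expr, dividend) in (0, 1)
            for expr in indexing_exprs.values()
            if not expr.has(sympy.floor)
        )
    )
    print(dividend, divisor, splittable)  # z3 30 True
```

When `splittable` is true, `z3` (range 960) can be split into `30*z3 + z4` exactly as described in the summary.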
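Finally, a hedged sketch of how the change can be observed end to end, mirroring the updated test above. It assumes the internal helpers `run_and_get_cpp_code` (from `torch._inductor.utils`) and `FileCheck` (from `torch.testing`), as used in test_cpu_repro.py; these are not public API.

```
import torch
import torch.nn as nn
from torch._inductor.utils import run_and_get_cpp_code  # internal helper, as in the test above
from torch.testing import FileCheck


class GN(nn.Module):
    def __init__(self, num_groups, num_channels):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        return self.gn(x)


x = torch.randn(2, 960, 96, 96).to(memory_format=torch.channels_last)
m = GN(32, 960).eval()

with torch.no_grad():
    expected = m(x)
    compiled_m = torch.compile(m, dynamic=True)
    actual, code = run_and_get_cpp_code(compiled_m, x)

torch.testing.assert_close(expected, actual)
# With the loop split applied there should be no non-contiguous loads, i.e. no
# "__at_align__ std::array" temporaries in the generated channels-last kernel.
FileCheck().check_count("__at_align__ std::array", 0, exactly=True).run(code)
```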