[inductor] Relax the conditions for loop split (pytorch#135335)
Summary

This PR relaxes the conditions for loop split to support dynamic shape cases. The conditions that now need to be met to apply the loop split optimization are:

1. No reduction and no modular index in any node.
2. The `indexing_exprs` of all nodes contain only one division (or more than one, but all the same), where the divisor is an integer and the dividend is one of the `iter_vars`; this var, i.e. the dimension that needs to be split, is contiguous in all other `indexing_exprs` (a small illustrative sketch of this check follows the example below).

Example:

```
import torch
import torch.nn as nn


class GN(torch.nn.Module):
    def __init__(self, num_groups, num_channels):
        super(GN, self).__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        return self.gn(x)


input = torch.randn(2, 960, 96, 96).to(memory_format=torch.channels_last)
m = GN(32, 960).eval()
compiled_m = torch.compile(m, dynamic=True)
with torch.no_grad():
    compiled_m(input)
```
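To make condition 2 concrete, here is a minimal sketch of the kind of check it describes, written against plain sympy rather than Inductor's internals; `find_split_var` and everything in it are illustrative assumptions, not the actual implementation, and the contiguity part of the condition is omitted:

```
import sympy


def find_split_var(index_exprs, iter_vars):
    # Hypothetical helper (not Inductor code): collect floor-division
    # patterns `var // divisor` whose divisor is an integer and whose
    # dividend is an iter_var. Divisions over shape symbols only, such
    # as (s2**2//s2) in `index4` below, do not involve an iter_var and
    # are simply ignored here.
    found = set()
    for expr in index_exprs:
        for fl in expr.atoms(sympy.floor):
            num, den = fl.args[0].as_numer_denom()
            if den.is_Integer and num in iter_vars:
                found.add((num, int(den)))
    # "only one (or more, but all the same) division"
    return found.pop() if len(found) == 1 else None


z0, z3 = sympy.symbols("z0 z3", integer=True, nonnegative=True)
exprs = [960 * z0 + z3, 32 * z0 + sympy.floor(z3 / 30)]
print(find_split_var(exprs, {z0, z3}))  # (z3, 30): split z3 with divisor 30
```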
Before loop split, the node's `var_ranges` is `{z0: s0, z1: s2, z2: s2, z3: 960}` and its `indexing_exprs` is:

```
{'index0': 960*s2**2*z0 + 960*s2*z1 + 960*z2 + z3,
 'index1': 32*z0 + (z3//30),
 'index2': 30*s2**2,
 'index3': z3,
 'index4': 960*s2*z0*((s2**2//s2)) + 960*z1*((s2**2//s2)) + 960*z2 + z3}
```

After loop split, `z3` is split into `30*z3 + z4`, so the node's `var_ranges` becomes `{z0: s0, z1: s2, z2: s2, z3: 32, z4: 30}` and its `indexing_exprs` becomes:

```
{'index0': 960*s2**2*z0 + 960*s2*z1 + 960*z2 + 30*z3 + z4,
 'index1': 32*z0 + z3,
 'index2': 30*s2**2,
 'index3': 30*z3 + z4,
 'index4': 960*s2*z0*((s2**2//s2)) + 960*z1*((s2**2//s2)) + 960*z2 + 30*z3 + z4}
```
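The rewrite itself is just a substitution on the indexing expressions. A minimal sympy sketch of that step, under the same caveat that this mirrors the example rather than Inductor's actual code:

```
import sympy

z0, z3, z4 = sympy.symbols("z0 z3 z4", integer=True, nonnegative=True)

# `index1` before the split contains the triggering division z3 // 30.
index1 = 32 * z0 + sympy.floor(z3 / 30)

# Split z3 (range 960) into an outer var of range 32 and an inner var
# z4 of range 30, i.e. substitute z3 -> 30*z3 + z4.
index1_split = index1.subs(z3, 30 * z3 + z4)  # 32*z0 + z3 + floor(z4/30)

# Since 0 <= z4 < 30, floor(z4/30) is 0; sympy cannot infer that range
# from the symbol alone, so encode the knowledge explicitly.
index1_new = index1_split.subs(sympy.floor(z4 / 30), 0)
print(index1_new)  # 32*z0 + z3, matching 'index1' after the split
```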
Generated code:
- Before:

```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'const int64_t', 'const int64_t'], '''
#include "/tmp/torchinductor_jiayisun/32/c32dcqa3qidvmunis4lucp3dhoicleq5qjfjfgvpiadbbzfp6ofy.h"
extern "C" void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2,
                       const int64_t ks0,
                       const int64_t ks1)
{
    #pragma omp parallel num_threads(112)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(ks0); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(32L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<int64_t>(c10::div_floor_integer(static_cast<int64_t>((15L*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(8L))));
                        for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(static_cast<int64_t>(ks1*ks1)); x2+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(16L); x3+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (30L*x1) + (960L*x2) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(16));
                                tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                            }
                            for(int64_t x3=static_cast<int64_t>(16L); x3<static_cast<int64_t>(30L); x3+=static_cast<int64_t>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (30L*x1) + (960L*x2) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(14L));
                                masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, static_cast<int64_t>(14L), &wrecps0);
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<int64_t>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<int64_t>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(ks0); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(ks1); x1+=static_cast<int64_t>(1L))
                {
                    #pragma GCC ivdep
                    for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(ks1); x2+=static_cast<int64_t>(1L))
                    {
                        #pragma GCC ivdep
                        for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(960L); x3+=static_cast<int64_t>(1L))
                        {
                            auto tmp0 = in_ptr0[static_cast<int64_t>(x3 + (960L*x2) + (960L*ks1*x1) + (960L*x0*(static_cast<int64_t>(ks1*ks1))))];
                            auto tmp1 = out_ptr0[static_cast<int64_t>((32L*x0) + (c10::div_floor_integer(static_cast<int64_t>(x3), static_cast<int64_t>(30L))))];
                            auto tmp3 = out_ptr1[static_cast<int64_t>((32L*x0) + (c10::div_floor_integer(static_cast<int64_t>(x3), static_cast<int64_t>(30L))))];
                            auto tmp11 = in_ptr1[static_cast<int64_t>(x3)];
                            auto tmp13 = in_ptr2[static_cast<int64_t>(x3)];
                            auto tmp2 = decltype(tmp0)(tmp0 - tmp1);
                            auto tmp4 = 30L*(static_cast<int64_t>(ks1*ks1));
                            auto tmp5 = c10::convert<float>(tmp4);
                            auto tmp6 = tmp3 / tmp5;
                            auto tmp7 = static_cast<float>(1e-05);
                            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
                            auto tmp9 = 1 / std::sqrt(tmp8);
                            auto tmp10 = decltype(tmp2)(tmp2 * tmp9);
                            auto tmp12 = decltype(tmp10)(tmp10 * tmp11);
                            auto tmp14 = decltype(tmp12)(tmp12 + tmp13);
                            out_ptr2[static_cast<int64_t>(x3 + (960L*x2) + (960L*x1*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1)))) + (960L*ks1*x0*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1)))))] = tmp14;
                        }
                    }
                }
            }
        }
    }
}
''')
async_compile.wait(globals())
del async_compile


def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
    args.clear()
    s0 = arg2_1
    s2 = arg3_1
    assert_size_stride(arg0_1, (960, ), (1, ))
    assert_size_stride(arg1_1, (960, ), (1, ))
    assert_size_stride(arg4_1, (s0, 960, s2, s2), (960*(s2*s2), 1, 960*s2, 960))
    buf0 = empty_strided_cpu((s0, 32, 1, 1), (32, 1, 32*s0, 32*s0), torch.float32)
    buf1 = empty_strided_cpu((s0, 32, 1, 1), (32, 1, 32*s0, 32*s0), torch.float32)
    buf3 = empty_strided_cpu((s0, 960, s2, s2), (960*s2*((s2*s2) // s2), 1, 960*((s2*s2) // s2), 960), torch.float32)
    cpp_fused_native_group_norm_0(arg4_1, arg0_1, arg1_1, buf0, buf1, buf3, s0, s2)
    del arg0_1
    del arg1_1
    del arg4_1
    return (buf3, )
```

- After:

```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'const int64_t', 'const int64_t'], '''
#include "/tmp/torchinductor_jiayisun/32/c32dcqa3qidvmunis4lucp3dhoicleq5qjfjfgvpiadbbzfp6ofy.h"
extern "C" void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2,
                       const int64_t ks0,
                       const int64_t ks1)
{
    #pragma omp parallel num_threads(112)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(ks0); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(32L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<int64_t>(c10::div_floor_integer(static_cast<int64_t>((15L*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(8L))));
                        for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(static_cast<int64_t>(ks1*ks1)); x2+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(16L); x3+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (30L*x1) + (960L*x2) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(16));
                                tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                            }
                            for(int64_t x3=static_cast<int64_t>(16L); x3<static_cast<int64_t>(30L); x3+=static_cast<int64_t>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (30L*x1) + (960L*x2) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(14L));
                                masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, static_cast<int64_t>(14L), &wrecps0);
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<int64_t>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<int64_t>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(ks0); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(ks1); x1+=static_cast<int64_t>(1L))
                {
                    #pragma GCC ivdep
                    for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(ks1); x2+=static_cast<int64_t>(1L))
                    {
                        #pragma GCC ivdep
                        for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(32L); x3+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x4=static_cast<int64_t>(0L); x4<static_cast<int64_t>(16L); x4+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x4 + (30L*x3) + (960L*x2) + (960L*ks1*x1) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(16));
                                auto tmp1 = out_ptr0[static_cast<int64_t>(x3 + (32L*x0))];
                                auto tmp4 = out_ptr1[static_cast<int64_t>(x3 + (32L*x0))];
                                auto tmp13 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x4 + (30L*x3)), static_cast<int64_t>(16));
                                auto tmp15 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x4 + (30L*x3)), static_cast<int64_t>(16));
                                auto tmp2 = at::vec::Vectorized<float>(tmp1);
                                auto tmp3 = tmp0 - tmp2;
                                auto tmp5 = 30L*(static_cast<int64_t>(ks1*ks1));
                                auto tmp6 = c10::convert<float>(tmp5);
                                auto tmp7 = tmp4 / tmp6;
                                auto tmp8 = static_cast<float>(1e-05);
                                auto tmp9 = decltype(tmp7)(tmp7 + tmp8);
                                auto tmp10 = 1 / std::sqrt(tmp9);
                                auto tmp11 = at::vec::Vectorized<float>(tmp10);
                                auto tmp12 = tmp3 * tmp11;
                                auto tmp14 = tmp12 * tmp13;
                                auto tmp16 = tmp14 + tmp15;
                                tmp16.store(out_ptr2 + static_cast<int64_t>(x4 + (30L*x3) + (960L*x2) + (960L*x1*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1)))) + (960L*ks1*x0*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1))))));
                            }
                            for(int64_t x4=static_cast<int64_t>(16L); x4<static_cast<int64_t>(30L); x4+=static_cast<int64_t>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x4 + (30L*x3) + (960L*x2) + (960L*ks1*x1) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(14L));
                                auto tmp1 = out_ptr0[static_cast<int64_t>(x3 + (32L*x0))];
                                auto tmp4 = out_ptr1[static_cast<int64_t>(x3 + (32L*x0))];
                                auto tmp13 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x4 + (30L*x3)), static_cast<int64_t>(14L));
                                auto tmp15 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x4 + (30L*x3)), static_cast<int64_t>(14L));
                                auto tmp2 = at::vec::Vectorized<float>(tmp1);
                                auto tmp3 = tmp0 - tmp2;
                                auto tmp5 = 30L*(static_cast<int64_t>(ks1*ks1));
                                auto tmp6 = c10::convert<float>(tmp5);
                                auto tmp7 = tmp4 / tmp6;
                                auto tmp8 = static_cast<float>(1e-05);
                                auto tmp9 = decltype(tmp7)(tmp7 + tmp8);
                                auto tmp10 = 1 / std::sqrt(tmp9);
                                auto tmp11 = at::vec::Vectorized<float>(tmp10);
                                auto tmp12 = tmp3 * tmp11;
                                auto tmp14 = tmp12 * tmp13;
                                auto tmp16 = tmp14 + tmp15;
                                tmp16.store(out_ptr2 + static_cast<int64_t>(x4 + (30L*x3) + (960L*x2) + (960L*x1*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1)))) + (960L*ks1*x0*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1))))), static_cast<int64_t>(14L));
                            }
                        }
                    }
                }
            }
        }
    }
}
''')
async_compile.wait(globals())
del async_compile


def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
    args.clear()
    s0 = arg2_1
    s2 = arg3_1
    assert_size_stride(arg0_1, (960, ), (1, ))
    assert_size_stride(arg1_1, (960, ), (1, ))
    assert_size_stride(arg4_1, (s0, 960, s2, s2), (960*(s2*s2), 1, 960*s2, 960))
    buf0 = empty_strided_cpu((s0, 32, 1, 1), (32, 1, 32*s0, 32*s0), torch.float32)
    buf1 = empty_strided_cpu((s0, 32, 1, 1), (32, 1, 32*s0, 32*s0), torch.float32)
    buf3 = empty_strided_cpu((s0, 960, s2, s2), (960*s2*((s2*s2) // s2), 1, 960*((s2*s2) // s2), 960), torch.float32)
    cpp_fused_native_group_norm_0(arg4_1, arg0_1, arg1_1, buf0, buf1, buf3, s0, s2)
    del arg0_1
    del arg1_1
    del arg4_1
    return (buf3, )
```
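For reference, output like the above can be reproduced roughly as follows; this is a sketch assuming PyTorch 2.x's artifact logging (`torch._logging.set_logs(output_code=True)`, or equivalently the `TORCH_LOGS="output_code"` environment variable), which may differ across versions:

```
import torch
import torch.nn as nn

# Assumption: the output_code logging artifact prints Inductor's
# generated wrapper and C++ kernel code to the log.
torch._logging.set_logs(output_code=True)

m = nn.GroupNorm(32, 960).eval()
compiled_m = torch.compile(m, dynamic=True)
x = torch.randn(2, 960, 96, 96).to(memory_format=torch.channels_last)
with torch.no_grad():
    compiled_m(x)
```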
Pull Request resolved: pytorch#135335
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jansel