diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index bc65658a1e2ec1..e142b56d02fc6a 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -3432,29 +3432,27 @@ def forward(self, x): return self.group_norm(x) options = itertools.product( - vec_dtypes, [torch.contiguous_format, torch.channels_last] + vec_dtypes, [torch.contiguous_format, torch.channels_last], [True, False] ) - for dtype, fmt in options: + for dtype, fmt, dynamic in options: torch._dynamo.reset() metrics.reset() mod = M().eval() x = torch.randn((2, 90, 6, 6), dtype=dtype).to(memory_format=fmt) with torch.no_grad(): - self.common(mod, (x,)) + expected = mod(x) + compiled_m = torch.compile(mod, dynamic=dynamic) + actual, code = run_and_get_cpp_code(compiled_m, x) + self.assertEqual(expected, actual) # 2 generated kernels (one for var_mean, the other for result) check_metrics_vec_kernel_count(2) - # check loop split optimization - if fmt == torch.channels_last: - torch._dynamo.reset() - metrics.reset() - with torch.no_grad(): - opt_mod = torch.compile(mod) - _, code = run_and_get_cpp_code(opt_mod, x) - # check that there are no non_contiguous loads - FileCheck().check_count("__at_align__ std::array", 0, exactly=True).run( - code - ) + # check loop split optimization + if fmt == torch.channels_last: + # check that there are no non_contiguous loads + FileCheck().check_count( + "__at_align__ std::array", 0, exactly=True + ).run(code) def test_int_div_vec(self): def fn(x, y, mode): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 8bbe8023a312af..b483f588f99d0e 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -4159,9 +4159,9 @@ def try_loop_split(self, nodes: List[SchedulerNode]): When one of the indexing_exprs contains a division, we eliminate the division by splitting the loop to avoid non-contiguous loads, subject to the following conditions: 1. No reduction and no mudular index for all nodes. - 2. Only one node's one indexing_exprs contains a division, according to this indexing_exprs, - we can get the dimension that needs to be split, and the split dimension is contiguous - in all other indexing_exprs. + 2. The indexing_exprs of all nodes contain only one (or more, but all the same) division, + where the divisor is an integer, the dividend is one of the iter_vars, and this var, + i.e. the dimension that needs to be split, is contiguous in all other indexing_exprs. For example, if the node's var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs: {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2}, @@ -4182,8 +4182,8 @@ def try_loop_split(self, nodes: List[SchedulerNode]): split_var = None split_number = None - divide_index_name = None num_div = 0 + div_expr_ = None match_div = False matched_node = None @@ -4191,25 +4191,27 @@ def try_loop_split(self, nodes: List[SchedulerNode]): assert isinstance(node.node, ir.ComputedBuffer) _, original_body, _ = node.node.get_default_sizes_body() for name, expr in original_body.indexing_exprs.items(): - num_div += expr.count(FloorDiv) - if num_div > 1: - return nodes - if expr.count(FloorDiv) == 1: - div_expr = expr.find(FloorDiv).pop() - split_var = div_expr.args[0] - split_number = div_expr.args[1] - divide_index_name = name + for div_expr in expr.find(FloorDiv): if ( - isinstance(split_number, sympy.core.numbers.Integer) - and isinstance(split_var, sympy.core.symbol.Symbol) - and split_var in original_body.iter_vars - and divide_index_name is not None + any(div_expr.has(var) for var in original_body.iter_vars) + and div_expr != div_expr_ + ): + div_expr_ = div_expr + num_div += 1 + if num_div > 1: + return nodes + if ( + isinstance(div_expr.args[1], sympy.core.numbers.Integer) + and div_expr.args[0] in original_body.iter_vars + and name is not None and all( - stride_at_vec_range(expr, split_var) == 1 - for name, expr in original_body.indexing_exprs.items() - if name != divide_index_name + stride_at_vec_range(expr_, div_expr.args[0]) in (0, 1) + for name_, expr_ in original_body.indexing_exprs.items() + if name_ != name ) ): + split_var = div_expr.args[0] + split_number = div_expr.args[1] match_div = True matched_node = node