diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 65d36c713e02..32a22226a4fe 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -147,6 +147,20 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
 }
 
 bool isExpensiveView(Type srcType, Type dstType) {
+  auto tensorSrcType = cast<RankedTensorType>(srcType);
+  auto tensorDstType = cast<RankedTensorType>(dstType);
+  auto llSrc =
+      toLinearLayout(tensorSrcType.getShape(), tensorSrcType.getEncoding());
+  auto llDst =
+      toLinearLayout(tensorDstType.getShape(), tensorDstType.getEncoding());
+  // In case there are replicated values we need to make sure the new and old
+  // layouts have matching free-variable masks.
+  for (auto [srcMask, dstMask] :
+       llvm::zip(llSrc.getFreeVariableMasks(), llDst.getFreeVariableMasks())) {
+    assert(srcMask.first == dstMask.first);
+    if (srcMask.second != dstMask.second)
+      return true;
+  }
   return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
 }
 
diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir
index fd794c7bd5ce..b96244996417 100644
--- a/test/TritonGPU/canonicalize.mlir
+++ b/test/TritonGPU/canonicalize.mlir
@@ -40,6 +40,25 @@ tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blo
 
 // -----
 
+// test that the convert doesn't get combined with view if the resulting operation
+// is an expensive view which would require moving data across threads.
+// CHECK-LABEL: @test_canonicalize_convert_expensive_view2
+// CHECK-SAME: (%[[ARG:.+]]: tensor<2xf32
+// CHECK: %[[C:.+]] = ttg.convert_layout %[[ARG]]
+// CHECK: %[[V:.+]] = tt.reshape %[[C]] allow_reorder
+// CHECK: tt.return %[[V]]
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} {
+  tt.func @test_canonicalize_convert_expensive_view2(%arg0: tensor<2xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) -> tensor<2xf32, #blocked1> {
+    %c = ttg.convert_layout %arg0 : tensor<2xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<2xf32, #blocked1>
+    %r = tt.reshape %c allow_reorder : tensor<2xf32, #blocked1> -> tensor<2xf32, #blocked1>
+    tt.return %r : tensor<2xf32, #blocked1>
+  }
+}
+
+// -----
+
 // test that the convert does get combined with the view even if the resulting operation
 // is an efficient view.
 // CHECK-LABEL: @test_canonicalize_convert_view