diff --git a/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir b/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir
index 127badcfa1560..28f60b1b2cb87 100644
--- a/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir
+++ b/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir
@@ -32,7 +32,8 @@ func.func @lower_tile_extract_insert(%arg0: tensor<512x128xbf16>,
 // CHECK: tt.return
 
 // CHECK-TMA-LABEL:tt.func @lower_tile_extract_insert
-// CHECK-TMA-SAME: %[[ARG_0:.*]]: !tt.ptr, %[[ARG_1:.*]]: !tt.ptr
+// CHECK-TMA-SAME: %[[ARG_0:.*]]: !tt.ptr {tt.nv_tma_desc = 1 : i32, tt.tma_descriptor = #triton_xla.tma_descriptor},
+// CHECK-TMA-SAME: %[[ARG_1:.*]]: !tt.ptr {tt.nv_tma_desc = 1 : i32, tt.tma_descriptor = #triton_xla.tma_descriptor}
 // CHECK-TMA: %[[DESC_0:.*]] = tt.reinterpret_tensor_descriptor %[[ARG_0]]
 // CHECK-TMA: %[[DESC_1:.*]] = tt.reinterpret_tensor_descriptor %[[ARG_1]]
 // CHECK-TMA: %[[LOAD:.*]] = tt.experimental_descriptor_load %[[DESC_0]]
diff --git a/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc b/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc
index 3d50ab4cfb0e8..156a039f5a530 100644
--- a/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc
+++ b/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc
@@ -16,6 +16,7 @@ limitations under the License.
 
 #include
 #include
+#include
 #include
 #include
 #include
@@ -137,12 +138,6 @@ void ComputeBoundaryChecks(std::vector& boundary_checks,
 }
 
 struct RewriteFuncOp : mlir::OpRewritePattern {
-  RewriteFuncOp(mlir::MLIRContext* context,
-                const stream_executor::DeviceDescription* device_description,
-                bool tma_enabled)
-      : OpRewritePattern(context),
-        device_description(device_description),
-        tma_enabled(tma_enabled) {}
   using OpRewritePattern::OpRewritePattern;
 
   // Rewrite tensors<> to !tt.ptr
@@ -183,6 +178,14 @@ struct RewriteFuncOp : mlir::OpRewritePattern {
         op.getContext(), new_operand_types, /*result_types=*/{});
     auto new_func = rewriter.create(op.getLoc(), op.getName(),
                                     new_function_type);
+    // Transfer the argument attributes from the old function to the new one.
+    if (op.getArgAttrs().has_value()) {
+      auto oldArgAttrsArray = op.getArgAttrs().value();
+      for (int i = 0; i < oldArgAttrsArray.size(); ++i) {
+        new_func.setArgAttrs(
+            i, mlir::cast(oldArgAttrsArray[i]));
+      }
+    }
 
     rewriter.inlineRegionBefore(op.getRegion(), new_func.getFunctionBody(),
                                 new_func.end());
@@ -195,9 +198,6 @@ struct RewriteFuncOp : mlir::OpRewritePattern {
 
     return mlir::success();
   }
-
-  const stream_executor::DeviceDescription* device_description;
-  const bool tma_enabled;
 };
 
 struct RewriteTile : mlir::OpRewritePattern {
@@ -215,17 +215,44 @@ struct RewriteTile : mlir::OpRewritePattern {
       TileOp op, mlir::PatternRewriter& rewriter) const override {
     ::xla::EmitterLocOpBuilder builder(op.getLoc(), rewriter);
 
-    // tensor -> !tt.ptr<>
-    auto cast_to_tensor_ptr_type =
-        builder
-            .create(
-                GetTensorPtrType(builder,
-                                 op.getTensor().getType().getElementType()),
-                op.getTensor())
-            .getResult(0);
-
     if (CanUseTMA(tma_enabled, *device_description, op.getTiledTensor().getType())) {
+      // Add TMA attributes to the equivalent argument in the function.
+      if (auto block_arg = mlir::dyn_cast(op.getTensor())) {
+        mlir::Operation* parent_op = block_arg.getOwner()->getParentOp();
+        if (auto func_op = mlir::dyn_cast(parent_op)) {
+          // TODO(manany): Revisit. This needs to be part of the eligibility
+          // check inside CanUseTMA itself. Should we enforce TileOp to be
+          // rewritten first and then have the other patterns check a flag?
+          if (std::distance(block_arg.getUsers().begin(),
+                            block_arg.getUsers().end()) == 1) {
+            func_op.setArgAttr(block_arg.getArgNumber(), "tt.nv_tma_desc",
+                               builder.getI32IntegerAttr(1));
+            // TODO(manany): We need to prefix the attribute name with "tt",
+            // otherwise tt.func will complain that it has an attribute that
+            // is not part of the dialect. Is there a better way to do this?
+            func_op.setArgAttr(
+                block_arg.getArgNumber(), "tt.tma_descriptor",
+                builder.getAttr(
+                    op.getTiledTensor().getType().getOriginalShape(),
+                    op.getTiledTensor().getType().getTileShape(),
+                    op.getTiledTensor()
+                            .getType()
+                            .getElementType()
+                            .getIntOrFloatBitWidth() /
+                        8));
+          }
+        }
+      }
+
+      // tensor -> !tt.ptr<>
+      auto cast_to_tensor_ptr_type =
+          builder
+              .create(
+                  GetTensorPtrType(builder,
+                                   op.getTensor().getType().getElementType()),
+                  op.getTensor())
+              .getResult(0);
+
       auto reinterpret_tensor_desc =
           xg::EmitTmaDescriptor(builder, cast_to_tensor_ptr_type,
                                 op.getTiledTensor().getType().getTileType());
@@ -238,6 +265,7 @@ struct RewriteTile : mlir::OpRewritePattern {
           reinterpret_tensor_desc);
 
       rewriter.replaceOp(op, cast_desc_ptr_to_tiled_tensor_ptr_type);
+
      return mlir::success();
     }
 
@@ -245,6 +273,15 @@ struct RewriteTile : mlir::OpRewritePattern {
     std::vector dim_order(op.getSizes().size());
     std::iota(dim_order.begin(), dim_order.end(), 0);
 
+    // tensor -> !tt.ptr<>
+    auto cast_to_tensor_ptr_type =
+        builder
+            .create(
+                GetTensorPtrType(builder,
+                                 op.getTensor().getType().getElementType()),
+                op.getTensor())
+            .getResult(0);
+
     auto tensor_ptr =
         builder
             .create(
@@ -473,10 +510,11 @@ struct TritonXLAExtractInsertToTritonPass
     mlir::MLIRContext* mlir_context = &getContext();
     mlir::RewritePatternSet patterns(mlir_context);
     // clang-format off
-    patterns.add(mlir_context);
+    patterns.add(mlir_context);
     patterns.add<
         RewriteExtract,
-        RewriteFuncOp,
         RewriteInsert,
         RewriteTile
     >(mlir_context, &device_description, tma_enabled);