diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index a0684a047575..19d4e5f5dcfc 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -27,6 +27,24 @@ using triton::gpu::DotOperandEncodingAttr; using triton::gpu::MmaEncodingAttr; using triton::gpu::SliceEncodingAttr; +struct PatternSharedInfo { + // If a conversion cannot be eliminated with a high-benefit pattern (e.g., + // SimplifyConversion, RematerializeBackward), it will be pushed forward in + // the hope that this will enable the elimination of this conversion later. + // However, pushing a conversion forward can introduce more conversions: + // op(cvt(arg_0), arg_1, ..., arg_n) -> cvt(op(arg_0, cvt(arg_1), ..., + // cvt(arg_n))). This is why the RematerializeForward pattern performs an + // analysis to determine whether these added conversions can be eliminated + // later. The RematerializeBackward pattern, applied after pushing this + // conversion forward, will eliminate these newly added conversions by + // reversing the process achieved with RematerializeForward. This can create + // an infinite loop between these two optimizations. To avoid this, we keep + // track of the conversions that were pushed forward and skip them in the + // RematerializeBackward pattern. A similar kind of loop can occur with the + // RematerializeForward and MoveConvertOutOfLoop patterns. 
+ llvm::DenseMap cvtsPushedForwardMap; +}; + // ----------------------------------------------------------------------------- // // ----------------------------------------------------------------------------- @@ -204,6 +222,7 @@ class SimplifyConversion : public mlir::RewritePattern { // -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n))) void pushConversionForward(triton::gpu::ConvertLayoutOp cvt, SetVector &cvtSlices, + PatternSharedInfo &sharedInfo, mlir::PatternRewriter &rewriter) { auto srcEncoding = cvt.getOperand().getType().cast().getEncoding(); @@ -237,6 +256,7 @@ void pushConversionForward(triton::gpu::ConvertLayoutOp cvt, newType.getShape(), newType.getElementType(), dstEncoding); auto newCvt = rewriter.create( newOp->getLoc(), newCvtType, newOp->getResult(0)); + sharedInfo.cvtsPushedForwardMap[newCvt] = newCvt->getOperand(0).getDefiningOp(); rewriter.replaceOp(op, newCvt->getResults()); } @@ -346,10 +366,12 @@ class MoveConvertOutOfIf : public mlir::RewritePattern { // class RematerializeForward : public mlir::RewritePattern { + PatternSharedInfo &sharedInfo; + public: - explicit RematerializeForward(mlir::MLIRContext *context) + explicit RematerializeForward(mlir::MLIRContext *context, PatternSharedInfo &sharedInfo) : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 1, context) {} + 1, context), sharedInfo(sharedInfo) {} mlir::LogicalResult matchAndRewrite(mlir::Operation *cvtOp, @@ -400,7 +422,7 @@ class RematerializeForward : public mlir::RewritePattern { } } - pushConversionForward(cvt, cvtSlices, rewriter); + pushConversionForward(cvt, cvtSlices, sharedInfo, rewriter); return success(); } }; @@ -412,16 +434,25 @@ class RematerializeForward : public mlir::RewritePattern { // even if it means rematerializing all values whose definitions // are reachable from it without passing through any memory operation. 
class RematerializeBackward : public mlir::RewritePattern { + PatternSharedInfo &sharedInfo; + public: - explicit RematerializeBackward(mlir::MLIRContext *context) + explicit RematerializeBackward(mlir::MLIRContext *context, PatternSharedInfo &sharedInfo) : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 3, context) {} + 3, context), sharedInfo(sharedInfo) {} + mlir::LogicalResult matchAndRewrite(mlir::Operation *cvt, mlir::PatternRewriter &rewriter) const override { if (!llvm::isa(cvt)) return mlir::failure(); + + auto it = sharedInfo.cvtsPushedForwardMap.find(cvt); + if (it != sharedInfo.cvtsPushedForwardMap.end() && + it->second == cvt->getOperand(0).getDefiningOp()) + return mlir::failure(); + // we don't touch block arguments Operation *op = cvt->getOperand(0).getDefiningOp(); if (!op) @@ -456,9 +487,13 @@ class RematerializeBackward : public mlir::RewritePattern { // ----------------------------------------------------------------------------- class MoveConvertOutOfLoop : public mlir::RewritePattern { + PatternSharedInfo &sharedInfo; + public: - explicit MoveConvertOutOfLoop(mlir::MLIRContext *context) - : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {} + explicit MoveConvertOutOfLoop(mlir::MLIRContext *context, + PatternSharedInfo &sharedInfo) + : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context), + sharedInfo(sharedInfo) {} SmallVector rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp, @@ -529,6 +564,9 @@ class MoveConvertOutOfLoop : public mlir::RewritePattern { // check for (auto *op : cvts) { auto cvt = dyn_cast(op); + auto it = sharedInfo.cvtsPushedForwardMap.find(cvt); + if (it != sharedInfo.cvtsPushedForwardMap.end()) + return mlir::failure(); auto targetType = op->getResultTypes()[0].cast(); auto newFor = rematerializeForLoop(rewriter, forOp, iterArg.index(), targetType, cvt); @@ -613,12 +651,13 @@ class TritonGPURemoveLayoutConversionsPass ModuleOp m = getOperation(); 
mlir::RewritePatternSet patterns(context); + PatternSharedInfo sharedInfo; patterns.add(context); patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); + patterns.add(context, sharedInfo); + patterns.add(context, sharedInfo); + patterns.add(context, sharedInfo); patterns.add(context); patterns.add(context); patterns.add(context);