From 13b4e925219cc97785b999d0b1132b5af824a28a Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Wed, 5 Feb 2025 23:05:51 +0100 Subject: [PATCH] [intel] Sync 'RemoveLayoutConversions.cpp' with Triton using '24b8d43' commit Signed-off-by: Anatoly Myachev --- .../RemoveLayoutConversions.cpp | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 6d72de79d6..36f63edf3a 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -1,6 +1,7 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" @@ -125,9 +126,6 @@ class LayoutRematerialization { return rematMapping.lookup({value, encoding}); } - bool hasRematValue(Value value, Attribute encoding) { - return rematMapping.contains({value, encoding}); - } void cleanup(); void backwardRematerialization(); void backwardRematerialization(ConvertLayoutOp convertOp); @@ -983,8 +981,8 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, auto layoutIt = layout.find(v); assert(layoutIt != layout.end()); // If we already have a remat value for this value, use it. - if (hasRematValue(v, layoutIt->second)) { - mapping.map(v, getRematValue(v, layoutIt->second)); + if (Value remat = getRematValue(v, layoutIt->second)) { + mapping.map(v, remat); continue; } if (v.getDefiningOp()) { @@ -1206,6 +1204,12 @@ void LayoutRematerialization::backwardRematerialization() { [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); for (ConvertLayoutOp convertOp : convertOps) { backwardRematerialization(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } } } @@ -1216,6 +1220,12 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); for (ConvertLayoutOp convertOp : convertOps) { hoistConvertOnTopOfExtOrBroadcast(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } } } @@ -1228,14 +1238,14 @@ void LayoutRematerialization::backwardRematerialization( dyn_cast(targetType.getEncoding())) if (isa(dotLayout.getParent())) return; - Value oldV = convertOp->getOperand(0); + Value oldV = convertOp.getSrc(); LDBG("check backward remat with source " << oldV << " encoding " << targetType.getEncoding()); // Check to see if there are existing remat'ed values for the pair of oldValue - // and encoding. - if (hasRematValue(oldV, targetType.getEncoding())) { + // and encoding. Make sure it dominates the current conversion. + Value newV = getRematValue(oldV, targetType.getEncoding()); + if (newV && domInfo.properlyDominates(newV, convertOp)) { // Replace it with the remat'ed value. - Value newV = getRematValue(oldV, targetType.getEncoding()); convertOp.replaceAllUsesWith(newV); opToDelete.insert(convertOp); LDBG("found remat'ed value" << newV);