Skip to content

Commit

Permalink
[intel] Sync 'RemoveLayoutConversions.cpp' with Triton using '24b8d43…
Browse files Browse the repository at this point in the history
…' commit

Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev committed Feb 5, 2025
1 parent b980165 commit 13b4e92
Showing 1 changed file with 19 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
Expand Down Expand Up @@ -125,9 +126,6 @@ class LayoutRematerialization {
return rematMapping.lookup({value, encoding});
}

bool hasRematValue(Value value, Attribute encoding) {
return rematMapping.contains({value, encoding});
}
void cleanup();
void backwardRematerialization();
void backwardRematerialization(ConvertLayoutOp convertOp);
Expand Down Expand Up @@ -983,8 +981,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
auto layoutIt = layout.find(v);
assert(layoutIt != layout.end());
// If we already have a remat value for this value, use it.
if (hasRematValue(v, layoutIt->second)) {
mapping.map(v, getRematValue(v, layoutIt->second));
if (Value remat = getRematValue(v, layoutIt->second)) {
mapping.map(v, remat);
continue;
}
if (v.getDefiningOp()) {
Expand Down Expand Up @@ -1206,6 +1204,12 @@ void LayoutRematerialization::backwardRematerialization() {
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
for (ConvertLayoutOp convertOp : convertOps) {
backwardRematerialization(convertOp);
if (!opToDelete.contains(convertOp)) {
// If the conversion didn't get removed, consider it for reuse in future
// backward slices.
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
convertOp.getResult());
}
}
}

Expand All @@ -1216,6 +1220,12 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
for (ConvertLayoutOp convertOp : convertOps) {
hoistConvertOnTopOfExtOrBroadcast(convertOp);
if (!opToDelete.contains(convertOp)) {
// If the conversion didn't get removed, consider it for reuse in future
// backward slices.
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
convertOp.getResult());
}
}
}

Expand All @@ -1228,14 +1238,14 @@ void LayoutRematerialization::backwardRematerialization(
dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding()))
if (isa<BlockedEncodingAttr>(dotLayout.getParent()))
return;
Value oldV = convertOp->getOperand(0);
Value oldV = convertOp.getSrc();
LDBG("check backward remat with source " << oldV << " encoding "
<< targetType.getEncoding());
// Check to see if there are existing remat'ed values for the pair of oldValue
// and encoding.
if (hasRematValue(oldV, targetType.getEncoding())) {
// and encoding. Make sure it dominates the current conversion.
Value newV = getRematValue(oldV, targetType.getEncoding());
if (newV && domInfo.properlyDominates(newV, convertOp)) {
// Replace it with the remat'ed value.
Value newV = getRematValue(oldV, targetType.getEncoding());
convertOp.replaceAllUsesWith(newV);
opToDelete.insert(convertOp);
LDBG("found remat'ed value" << newV);
Expand Down

0 comments on commit 13b4e92

Please sign in to comment.