Fix FA v2 hanging issue when BLOCK_N=32 (#274)
* Fix FA v2 hanging issue when BLOCK_N=32

* Fix broken tests
oplavsic authored Aug 10, 2023
1 parent a1f4ee6 commit 398d2c7
Showing 1 changed file with 49 additions and 10 deletions.
lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -27,6 +27,24 @@ using triton::gpu::DotOperandEncodingAttr;
 using triton::gpu::MmaEncodingAttr;
 using triton::gpu::SliceEncodingAttr;
 
+struct PatternSharedInfo {
+  // If a conversion cannot be eliminated by a high-benefit pattern (e.g.,
+  // SimplifyConversion, RematerializeBackward), it is pushed forward in the
+  // hope that it can be eliminated later. However, pushing a conversion
+  // forward can introduce more conversions
+  // (op(cvt(arg_0), arg_1, ..., arg_n) -> cvt(op(arg_0, cvt(arg_1), ...,
+  // cvt(arg_n)))). This is why the RematerializeForward pattern performs an
+  // analysis to determine whether these added conversions can be eliminated
+  // later. The RematerializeBackward pattern, applied after a conversion has
+  // been pushed forward, would eliminate the newly added conversions by
+  // reversing the rewrite performed by RematerializeForward. This can create
+  // an infinite loop between the two optimizations. To avoid it, we keep
+  // track of the conversions that were pushed forward and skip them in the
+  // RematerializeBackward pattern. A similar loop can occur between the
+  // RematerializeForward and MoveConvertOutOfLoop patterns.
+  llvm::DenseMap<Operation *, Operation *> cvtsPushedForwardMap;
+};
+
 // -----------------------------------------------------------------------------
 //
 // -----------------------------------------------------------------------------
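The comment above makes the cycle-breaking argument in prose. As a minimal standalone sketch of the same bookkeeping (plain C++, with std::map standing in for llvm::DenseMap and an empty Op type standing in for mlir::Operation; every name here is an illustrative stand-in, not Triton code):

#include <cassert>
#include <map>

// Stand-ins for mlir::Operation and llvm::DenseMap, for illustration only.
struct Op {};

struct PatternSharedInfoSketch {
  std::map<Op *, Op *> cvtsPushedForwardMap; // pushed cvt -> op it now follows
};

// Mirrors the check RematerializeBackward performs below: decline to pull a
// conversion back over its producer if the forward pattern just pushed it
// past that same producer.
bool backwardShouldSkip(const PatternSharedInfoSketch &info, Op *cvt,
                        Op *definingOp) {
  auto it = info.cvtsPushedForwardMap.find(cvt);
  return it != info.cvtsPushedForwardMap.end() && it->second == definingOp;
}

int main() {
  PatternSharedInfoSketch info;
  Op producer, pushedCvt, unrelatedCvt;

  // RematerializeForward records the conversion it just created.
  info.cvtsPushedForwardMap[&pushedCvt] = &producer;

  // RematerializeBackward now refuses to undo that exact move...
  assert(backwardShouldSkip(info, &pushedCvt, &producer));
  // ...but still fires on conversions it has no record of.
  assert(!backwardShouldSkip(info, &unrelatedCvt, &producer));
  return 0;
}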
@@ -204,6 +222,7 @@ class SimplifyConversion : public mlir::RewritePattern {
 // -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n)))
 void pushConversionForward(triton::gpu::ConvertLayoutOp cvt,
                            SetVector<Operation *> &cvtSlices,
+                           PatternSharedInfo &sharedInfo,
                            mlir::PatternRewriter &rewriter) {
   auto srcEncoding =
       cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
@@ -237,6 +256,7 @@ void pushConversionForward(triton::gpu::ConvertLayoutOp cvt,
       newType.getShape(), newType.getElementType(), dstEncoding);
   auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
       newOp->getLoc(), newCvtType, newOp->getResult(0));
+  sharedInfo.cvtsPushedForwardMap[newCvt] = newCvt->getOperand(0).getDefiningOp();
   rewriter.replaceOp(op, newCvt->getResults());
 }
 
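For reference, pushConversionForward rewrites op(cvt(a), b, ...) into cvt(op(a, cvt(b), ...)), and the added line above records the new trailing conversion keyed to the op that now defines its operand. A toy model of that shape (a hypothetical Expr tree in plain C++, not Triton's IR; mk and pushForward are invented names for illustration):

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// A hypothetical expression node standing in for an MLIR operation.
struct Expr {
  std::string kind; // "cvt", "add", or a leaf name
  std::vector<std::shared_ptr<Expr>> args;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr mk(std::string kind, std::vector<ExprPtr> args = {}) {
  return std::make_shared<Expr>(Expr{std::move(kind), std::move(args)});
}

// op(cvt(a), b, ...) -> cvt(op(a, cvt(b), ...)): strip the conversion from
// operand 0, wrap every other operand in a conversion, and emit a single
// conversion of the result. Record the new trailing cvt, keyed to the op
// that defines its operand (the cvtsPushedForwardMap analogue).
ExprPtr pushForward(const ExprPtr &op, std::map<Expr *, Expr *> &pushedMap) {
  std::vector<ExprPtr> newArgs{op->args[0]->args[0]};
  for (size_t i = 1; i < op->args.size(); ++i)
    newArgs.push_back(mk("cvt", {op->args[i]}));
  ExprPtr newOp = mk(op->kind, newArgs);
  ExprPtr newCvt = mk("cvt", {newOp});
  pushedMap[newCvt.get()] = newOp.get();
  return newCvt;
}

int main() {
  std::map<Expr *, Expr *> pushedMap;
  ExprPtr before = mk("add", {mk("cvt", {mk("a")}), mk("b")});
  ExprPtr after = pushForward(before, pushedMap);
  // Prints cvt(add(...)): the conversion now trails the op.
  std::cout << after->kind << "(" << after->args[0]->kind << "(...))\n";
}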
@@ -346,10 +366,12 @@ class MoveConvertOutOfIf : public mlir::RewritePattern {
 
 //
 class RematerializeForward : public mlir::RewritePattern {
+  PatternSharedInfo &sharedInfo;
+
 public:
-  explicit RematerializeForward(mlir::MLIRContext *context)
+  explicit RematerializeForward(mlir::MLIRContext *context, PatternSharedInfo &sharedInfo)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
-                             1, context) {}
+                             1, context), sharedInfo(sharedInfo) {}
 
   mlir::LogicalResult
   matchAndRewrite(mlir::Operation *cvtOp,
@@ -400,7 +422,7 @@ class RematerializeForward : public mlir::RewritePattern {
       }
     }
 
-    pushConversionForward(cvt, cvtSlices, rewriter);
+    pushConversionForward(cvt, cvtSlices, sharedInfo, rewriter);
     return success();
   }
 };
@@ -412,16 +434,25 @@ class RematerializeForward : public mlir::RewritePattern {
 // even if it means rematerializing all values whose definitions
 // are reachable from it without passing through any memory operation.
 class RematerializeBackward : public mlir::RewritePattern {
+  PatternSharedInfo &sharedInfo;
+
 public:
-  explicit RematerializeBackward(mlir::MLIRContext *context)
+  explicit RematerializeBackward(mlir::MLIRContext *context, PatternSharedInfo &sharedInfo)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
-                             3, context) {}
+                             3, context), sharedInfo(sharedInfo) {}
 
+
   mlir::LogicalResult
   matchAndRewrite(mlir::Operation *cvt,
                   mlir::PatternRewriter &rewriter) const override {
     if (!llvm::isa<triton::gpu::ConvertLayoutOp>(cvt))
       return mlir::failure();
 
+    auto it = sharedInfo.cvtsPushedForwardMap.find(cvt);
+    if (it != sharedInfo.cvtsPushedForwardMap.end() &&
+        it->second == cvt->getOperand(0).getDefiningOp())
+      return mlir::failure();
+
     // we don't touch block arguments
     Operation *op = cvt->getOperand(0).getDefiningOp();
     if (!op)
@@ -456,9 +487,13 @@ class RematerializeBackward : public mlir::RewritePattern {
 // -----------------------------------------------------------------------------
 
 class MoveConvertOutOfLoop : public mlir::RewritePattern {
+  PatternSharedInfo &sharedInfo;
+
 public:
-  explicit MoveConvertOutOfLoop(mlir::MLIRContext *context)
-      : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}
+  explicit MoveConvertOutOfLoop(mlir::MLIRContext *context,
+                                PatternSharedInfo &sharedInfo)
+      : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context),
+        sharedInfo(sharedInfo) {}
 
   SmallVector<Value, 4>
   rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp,
@@ -529,6 +564,9 @@ class MoveConvertOutOfLoop : public mlir::RewritePattern {
     // check
     for (auto *op : cvts) {
       auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
+      auto it = sharedInfo.cvtsPushedForwardMap.find(cvt);
+      if (it != sharedInfo.cvtsPushedForwardMap.end())
+        return mlir::failure();
       auto targetType = op->getResultTypes()[0].cast<RankedTensorType>();
       auto newFor = rematerializeForLoop(rewriter, forOp, iterArg.index(),
                                          targetType, cvt);
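Note that the guard added in this hunk is deliberately coarser than the one in RematerializeBackward: it bails out whenever the conversion appears in the map at all, without comparing defining ops. A small contrast of the two checks (same illustrative Op/map stand-ins as the earlier sketch; the contrast is an inference from the two hunks, not text from the commit):

#include <cassert>
#include <map>

struct Op {};
using CvtMap = std::map<Op *, Op *>; // cvtsPushedForwardMap stand-in

// RematerializeBackward: skip only the exact (cvt, producer) pairing.
bool backwardSkips(const CvtMap &m, Op *cvt, Op *def) {
  auto it = m.find(cvt);
  return it != m.end() && it->second == def;
}

// MoveConvertOutOfLoop: skip whenever the cvt was pushed forward at all,
// regardless of what currently defines its operand.
bool loopHoistSkips(const CvtMap &m, Op *cvt) { return m.count(cvt) != 0; }

int main() {
  CvtMap m;
  Op cvt, producer, otherProducer;
  m[&cvt] = &producer;
  assert(backwardSkips(m, &cvt, &producer));       // exact pairing: skip
  assert(!backwardSkips(m, &cvt, &otherProducer)); // pairing broken: may fire
  assert(loopHoistSkips(m, &cvt));                 // membership alone: skip
}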
@@ -613,12 +651,13 @@ class TritonGPURemoveLayoutConversionsPass
     ModuleOp m = getOperation();
 
     mlir::RewritePatternSet patterns(context);
+    PatternSharedInfo sharedInfo;
 
     patterns.add<SimplifyConversion>(context);
     patterns.add<SimplifyReduceCvt>(context);
-    patterns.add<RematerializeBackward>(context);
-    patterns.add<RematerializeForward>(context);
-    patterns.add<MoveConvertOutOfLoop>(context);
+    patterns.add<RematerializeBackward>(context, sharedInfo);
+    patterns.add<RematerializeForward>(context, sharedInfo);
+    patterns.add<MoveConvertOutOfLoop>(context, sharedInfo);
     patterns.add<MoveConvertOutOfIf>(context);
     patterns.add<DecomposeDotOperand>(context);
     patterns.add<ConvertDotConvert>(context);
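One detail implied by the registration above, though not stated in the diff: each of the three patterns stores a reference to sharedInfo, which lives on the pass's stack, so state written by one pattern during the greedy rewrite is visible to the others, and the pattern set must not outlive the struct. A compressed sketch of that reference sharing (illustrative names only):

// Sketch of the lifetime contract inferred from the diff: patterns capture
// the shared struct by reference, never by copy.
struct SharedState { int entries = 0; };

struct PatternWithState {
  SharedState &shared; // reference member: one struct shared by all patterns
  explicit PatternWithState(SharedState &s) : shared(s) {}
};

int main() {
  SharedState state;          // declared next to the pattern set, as
  PatternWithState p1(state); // sharedInfo is in runOnOperation()
  PatternWithState p2(state);
  p1.shared.entries++;        // a write through one pattern...
  return p2.shared.entries == 1 ? 0 : 1; // ...is seen by the other
}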