Skip to content

Commit

Permalink
COPYBARA SYNC:
Browse files Browse the repository at this point in the history
  - 9ca8fb2 [compiler] opt identity slice (#304)
  - 9b989ae [torch-frontend] add decompose-on-torch pass (#398)
  - 77a582f [tf-frontend] fix for gcc8.5 (#397)
  - e9f731e [torch-frontend] fix lowering math ops to custom call (#3...
  - d8f1ffd [compiler] add e2e generator and diff checker for MLPInfe...
  - 2b90dd9 [doc] update custom call doc (#394)
  - fe8de9c [tf-frontend] update tensorflow to bc42c0c1 (#387)
  - 12f2fd6 [torch-frontend] support convert math ops to custom call ...
  - e1546f4 [*] bump version to 1.9.0.0 (#391)
  - 5ed0761 [e2e] fix e2e after update llvm (#390)
  (And 10 more changes)

GitOrigin-RevId: 9ca8fb2
  • Loading branch information
Vremold committed Jul 10, 2024
1 parent bd82aab commit 5d9c578
Show file tree
Hide file tree
Showing 79 changed files with 2,540 additions and 418 deletions.
3 changes: 3 additions & 0 deletions compiler/include/byteir/Dialect/mhlo/Util/Util.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ bool isDenseMhloConstantValue(Value val);
template <typename RegionOp, typename Op = mhlo::ReduceOp>
bool isRegularReduceOp(Op op);

// Return true if slice region is continuous
bool isSliceContinuousSubview(mhlo::SliceOp op);

// return cumsum's index, return nullopt if not a cumsum op
std::optional<int64_t> getCumsumIndex(mhlo::ReduceWindowOp op);

Expand Down
36 changes: 33 additions & 3 deletions compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "byteir/Conversion/MemrefToByre/MemrefToByre.h"
#include "byteir/Dialect/Byre/ByreDialect.h"
#include "byteir/Dialect/Byre/Common.h"
#include "byteir/Utils/MemUtils.h"
#include "byteir/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
Expand All @@ -41,7 +42,7 @@ class ConvertReshapeLikeOpToByrePattern : public OpConversionPattern<OpTy> {
LogicalResult
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!op.getType().getLayout().isIdentity())
if (!isStaticShapeAndContiguousRowMajorEx(op.getType()))
return failure();

rewriter.replaceOpWithNewOp<byre::AliasOp>(op, op.getResult().getType(),
Expand Down Expand Up @@ -78,7 +79,7 @@ class ConvertSubViewOpToByrePattern
LogicalResult
matchAndRewrite(memref::SubViewOp op, memref::SubViewOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!op.getType().getLayout().isIdentity())
if (!isStaticShapeAndContiguousRowMajorEx(op.getType()))
return failure();

if (!op.getSource().getType().getLayout().isIdentity())
Expand All @@ -90,6 +91,32 @@ class ConvertSubViewOpToByrePattern
}
};

template <typename OpTy>
class ConvertMemrefCastOpToBtrePattern : public OpConversionPattern<OpTy> {
public:
ConvertMemrefCastOpToBtrePattern(MLIRContext *ctx)
: OpConversionPattern<OpTy>(ctx) {}

LogicalResult
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!isStaticShapeAndContiguousRowMajorEx(
op->getOperand(0).getType().template cast<MemRefType>()))
return failure();

int64_t offset = 0;
if constexpr (std::is_same_v<OpTy, memref::ReinterpretCastOp>) {
auto srcMemref = op.getSource().getType().template cast<MemRefType>();
SmallVector<int64_t> strides;
if (failed(getStridesAndOffset(srcMemref, strides, offset)))
return failure();
}
rewriter.replaceOpWithNewOp<byre::AliasOp>(op, op.getType(),
adaptor.getSource(), offset);
return success();
}
}; // ConvertMemrefCastOpToBtrePattern

class ConvertMemrefCopyOpToByrePattern
: public OpConversionPattern<memref::CopyOp> {
public:
Expand Down Expand Up @@ -194,7 +221,10 @@ void mlir::populateMemrefToByrePattern(RewritePatternSet &patterns) {
ConvertGetGlobalOpToByrePattern,
ConvertReshapeLikeOpToByrePattern<memref::CollapseShapeOp>,
ConvertReshapeLikeOpToByrePattern<memref::ExpandShapeOp>,
ConvertSubViewOpToByrePattern>(patterns.getContext());
ConvertSubViewOpToByrePattern,
ConvertMemrefCastOpToBtrePattern<memref::CastOp>,
ConvertMemrefCastOpToBtrePattern<memref::ReinterpretCastOp>>(
patterns.getContext());
}

std::unique_ptr<OperationPass<func::FuncOp>>
Expand Down
14 changes: 11 additions & 3 deletions compiler/lib/Dialect/Byre/IR/ByreDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,17 @@ struct CollapseAliasChain : public OpRewritePattern<AliasOp> {
LogicalResult matchAndRewrite(AliasOp aliasOp,
PatternRewriter &rewriter) const override {
if (auto sourceOp = aliasOp.getSource().getDefiningOp<AliasOp>()) {
rewriter.replaceOpWithNewOp<AliasOp>(
aliasOp, aliasOp.getTarget().getType(), sourceOp.getSource(),
aliasOp.getOffset() + sourceOp.getOffset());
auto srcElemBitwidth = cast<MemRefType>(sourceOp.getSource().getType())
.getElementType()
.getIntOrFloatBitWidth();
auto curElemBitwidth = cast<MemRefType>(aliasOp.getTarget().getType())
.getElementType()
.getIntOrFloatBitWidth();
auto newOffset = aliasOp.getOffset() * curElemBitwidth / srcElemBitwidth +
sourceOp.getOffset();
rewriter.replaceOpWithNewOp<AliasOp>(aliasOp,
aliasOp.getTarget().getType(),
sourceOp.getSource(), newOffset);
return success();
}
return failure();
Expand Down
75 changes: 74 additions & 1 deletion compiler/lib/Dialect/MemRef/Transforms/RemoveCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "byteir/Dialect/MemRef/Transforms/RemoveCopy.h"
#include "byteir/Dialect/MemRef/Utils/MemEffect.h"
#include "byteir/Utils/Hoist.h"
#include "byteir/Utils/MemUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
Expand Down Expand Up @@ -102,6 +103,45 @@ void replaceUsesAndPropagateType(RewriterBase &rewriter, Operation *oldOp,
rewriter.eraseOp(op);
}

// Check whether all uses of oldValue can be safely replaced with newValue after
// casting.
bool anyIncompatibleUseWithCast(Value oldValue, Value newValue) {
bool incompatible = llvm::any_of(oldValue.getUses(), [](OpOperand &operand) {
Operation *op = operand.getOwner();
Dialect *dialect = op->getDialect();
return llvm::isa<memref::CollapseShapeOp, memref::ExpandShapeOp,
func::CallOp>(op) ||
(dialect && dialect->getNamespace() == "byre");
});
incompatible &= (!isStaticShapeAndContiguousRowMajorEx(
oldValue.getType().cast<MemRefType>()) ||
!isStaticShapeAndContiguousRowMajorEx(
newValue.getType().cast<MemRefType>()));
return incompatible;
}

SmallVector<Operation *> getReshapeOp(Value value) {
SmallVector<Operation *> reshapeOps;
auto operation = value.getDefiningOp();
while (operation &&
isa<memref::CollapseShapeOp, memref::ExpandShapeOp>(operation)) {
reshapeOps.push_back(operation);
value = operation->getOperand(0);
operation = value.getDefiningOp();
}
if (operation && isa<memref::AllocOp>(operation))
return reshapeOps;
return {};
}

int64_t extractOffset(MemRefType memref) {
int64_t offset{0};
SmallVector<int64_t> strides;
if (failed(getStridesAndOffset(memref, strides, offset)))
return 0;
return offset;
}

class RemoveCopyPattern : public OpRewritePattern<memref::CopyOp> {
public:
RemoveCopyPattern(MLIRContext *context, DominanceInfo &dom)
Expand Down Expand Up @@ -217,7 +257,8 @@ class RemoveCopyPattern : public OpRewritePattern<memref::CopyOp> {
// we prefer target alloc over src alloc in this implementation
if (auto targetAlloc = target.getDefiningOp<memref::AllocOp>()) {
if (auto srcDef = src.getDefiningOp()) {
if (isa<memref::AllocOp, memref::SubViewOp>(srcDef))
if (isa<memref::AllocOp, memref::SubViewOp, memref::ExpandShapeOp,
memref::ExpandShapeOp>(srcDef))
hoistUpOpInBlock(srcDef, domInfo);
}

Expand All @@ -232,6 +273,38 @@ class RemoveCopyPattern : public OpRewritePattern<memref::CopyOp> {
replaceUsesAndPropagateType(rewriter, targetAlloc, src);
return success();
}

if (!anyIncompatibleUseWithCast(target, src)) {
// The memref of source and target are contiguous, cast source value to
// the same type with target. As `byre.alias` could handle source with
// offset, `memref.(reinterpret)cast` would be converted to `byre.alias`
// in pass `memref-to-byre`.
LLVM_DEBUG(llvm::dbgs()
<< "contiguous src type: " << src.getType() << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "contiguous dst type: " << target.getType() << "\n");

auto sourceMemref = src.getType().cast<MemRefType>();
auto targetMemref = target.getType().cast<MemRefType>();
int64_t srcMemrefOffset = extractOffset(sourceMemref);

Value srcCast;

if (srcMemrefOffset) {
SmallVector<int64_t> strides;
int64_t memrefOffset;
if (failed(getStridesAndOffset(targetMemref, strides, memrefOffset)))
return failure();
srcCast = rewriter.create<memref::ReinterpretCastOp>(
copyOp.getLoc(), targetMemref, src, memrefOffset,
targetMemref.getShape(), strides);
} else
srcCast = rewriter.create<memref::CastOp>(copyOp.getLoc(),
targetMemref, src);
rewriter.replaceAllUsesWith(targetAlloc, {srcCast});
rewriter.eraseOp(copyOp);
return success();
}
}

if (auto srcAlloc = src.getDefiningOp<memref::AllocOp>()) {
Expand Down
26 changes: 17 additions & 9 deletions compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,27 @@ bool isCustomMhloByteirRepeatOp(Operation *op) {
return false;
}

bool isAliasLikeOp(Operation *op) {
if (llvm::isa<mhlo::ReshapeOp>(op)) {
return true;
} else if (auto slice = llvm::dyn_cast_if_present<mhlo::SliceOp>(op)) {
return isSliceContinuousSubview(slice);
}
return false;
}

//===----------------------------------------------------------------------===//
// ElementwiseFusion
//===----------------------------------------------------------------------===//
namespace elementwise {

// TODO: maybe we should support non-splat constant on device in future
bool isFusibleCandidate(Operation *op) {
return isMhlo(op) &&
return isMhlo(op) && !isAliasLikeOp(op) &&
(op->hasTrait<::mlir::OpTrait::Elementwise>() ||
op->hasTrait<hlo::OpTrait::BroadcastingElementwise>() ||
isSplatMhloConstantLike(op) ||
isa<mhlo::BroadcastInDimOp, mhlo::BroadcastOp, mhlo::ReshapeOp>(op) ||
isa<mhlo::BroadcastInDimOp, mhlo::BroadcastOp>(op) ||
isCustomMhloRngOp(op));
}

Expand Down Expand Up @@ -332,7 +341,11 @@ namespace aggressive_fusion {
bool isFusibleCandidate(Operation *op) {
if (isCustomMhloRngOp(op) || isCustomMhloByteirRepeatOp(op))
return true;
return isMhlo(op) && !llvm::isa<mhlo::CustomCallOp>(op);
if (isAliasLikeOp(op))
return false;
if (llvm::isa<mhlo::CustomCallOp>(op))
return false;
return isMhlo(op);
}

bool isFusibleStart(Operation *) { return true; }
Expand All @@ -347,12 +360,7 @@ bool isFusibleWithNoDenseFuse(Operation *target, Operation * /*start*/) {
target);
}

bool isValidSingleOp(Operation *op) {
if (llvm::isa<mhlo::ReshapeOp>(op))
return false;
else
return true;
}
bool isValidSingleOp(Operation *op) { return true; }

bool isValidFusionPattern(const MhloFusionPattern &) { return true; }

Expand Down
40 changes: 38 additions & 2 deletions compiler/lib/Dialect/mhlo/Util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ bool mlir::isSplatMhloConstantValue(Value val, double splat_val) {
return isSplatMhloConstantValue(val.getDefiningOp(), splat_val);
}

// Return true if op is a regular reduce/reduce_window op, like reduce
// max/min/sum/any
template <typename RegionOp, typename Op> bool mlir::isRegularReduceOp(Op op) {
if (op.getInputs().size() != 1 || op.getInitValues().size() != 1 ||
op.getResults().size() != 1) {
Expand Down Expand Up @@ -154,6 +156,40 @@ template bool mlir::isRegularReduceOp<mhlo::AddOp, mhlo::ReduceWindowOp>(
template bool mlir::isRegularReduceOp<mhlo::MaxOp, mhlo::ReduceWindowOp>(
mhlo::ReduceWindowOp);

// Return true if slice region is continuous
bool mlir::isSliceContinuousSubview(mhlo::SliceOp op) {
auto type = cast<RankedTensorType>(op.getOperand().getType());
if (!type.hasStaticShape()) {
return false;
}
if (!isSplatValue(op.getStrides(), 1)) {
return false;
}

// find highest non one dimension
std::optional<int64_t> leadingNonOneDimensionIndex;
for (int64_t i = 0; i < type.getRank(); i++) {
if (type.getDimSize(i) != 1) {
leadingNonOneDimensionIndex = i;
break;
}
}
if (!leadingNonOneDimensionIndex.has_value()) {
return false;
}

for (int64_t i = 0; i < type.getRank(); i++) {
if (i != leadingNonOneDimensionIndex.value()) {
if (op.getStartIndices().getValues<int64_t>()[i] != 0 ||
op.getLimitIndices().getValues<int64_t>()[i] != type.getDimSize(i)) {
return false;
}
}
}
return true;
}

// return cumsum's index, return nullopt if not a cumsum op
std::optional<int64_t> mlir::getCumsumIndex(mhlo::ReduceWindowOp op) {
auto base_dilations = op.getBaseDilationsAttr();
if (base_dilations && !isSplatValue(base_dilations, 1)) {
Expand Down Expand Up @@ -182,7 +218,7 @@ std::optional<int64_t> mlir::getCumsumIndex(mhlo::ReduceWindowOp op) {
if (!inputShape.hasRank()) {
return std::nullopt;
}
int64_t index = K_INITIAL;
std::optional<int64_t> index;
for (int64_t i = 0; i < inputShape.getRank(); i++) {
if (window_dimensions[i] == 1 && padding[i * 2] == 0 &&
padding[i * 2 + 1] == 0) {
Expand All @@ -194,7 +230,7 @@ std::optional<int64_t> mlir::getCumsumIndex(mhlo::ReduceWindowOp op) {
} else if (window_dimensions[i] == inputShape.getDimSize(i) &&
padding[i * 2] == inputShape.getDimSize(i) - 1 &&
padding[i * 2 + 1] == 0) {
if (index == K_INITIAL) {
if (!index.has_value()) {
index = i;
} else {
// more than one dim to be cumsumed
Expand Down
4 changes: 3 additions & 1 deletion compiler/lib/Pipelines/ByreHost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ void createByreHostPipelineImpl(OpPassManager &pm, const std::string &entryFunc,
// currently use SetOpSpace + SetArgSpace to specify space here
// TODO: later move to GPUOpt after general copy finish
if (!target.empty()) {
pm.addNestedPass<func::FuncOp>(createSetOpSpacePass(entryFunc, target));
// FIXME(chhuang) disable set-op-space here to avoid set discardable attr to
// host side ops, which leads to serialize fail.
// pm.addNestedPass<func::FuncOp>(createSetOpSpacePass(entryFunc, target));
pm.addPass(createSetArgSpacePass(entryFunc, target, true));
}
}
Expand Down
7 changes: 5 additions & 2 deletions compiler/scripts/gen_testcases.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class E2ECollections:
])
def ByreTensorOptPipeline(filecheck, *, entryFunc="main"):
return OptPipeline(E2ECollections.ByreTensorOpt, [E2ECollections.BufferizeOpt], ["-byre-tensor-opt=\"append-arg-types entry-func={}\"".format(entryFunc)], filecheck)
BufferizeOptPipeline = functools.partial(OptPipeline, BufferizeOpt, [AffineOpt, SCFOpt], [
BufferizeOptPipeline = functools.partial(OptPipeline, BufferizeOpt, [SCFOpt], [
"-byteir-bufferize-opt",
])
AffineOptPipeline = functools.partial(OptPipeline, AffineOpt, [GPUOpt], [
Expand Down Expand Up @@ -187,7 +187,10 @@ def emitSingleTestcase(workdir, testcase):
print("===- start processing {} -===".format(workdir))
for i in testcase.contents:
assert isinstance(i, Content), "item in testcase.contents must be a Content"
for s in i.stages:
_stages = i.stages
if isinstance(_stages, Stage):
_stages = [_stages]
for s in _stages:
with workdir.joinpath(s.filename).open("w") as f:
f.write(i.content)

Expand Down
2 changes: 2 additions & 0 deletions compiler/scripts/gen_testcases_and_check_diff.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pushd $CUR_DIR

# CUDA
python3 gen_testcases.py --top-dir ./compiler/test/E2E/CUDA/MLPInference --category=E2E
python3 gen_testcases.py --top-dir ./compiler/test/E2E/CUDA/AliasLikeCUDA --category=E2E
# TODO: add more CUDA E2E checker

# Host
Expand All @@ -18,6 +19,7 @@ python3 gen_testcases.py --top-dir ../test/E2E/Host/RngNormal --category HostPip
python3 gen_testcases.py --top-dir ../test/E2E/Host/RngUniform --category HostPipeline
python3 gen_testcases.py --top-dir ../test/E2E/Host/Transpose --category HostPipeline
python3 gen_testcases.py --top-dir ../test/E2E/Host/TypeCvt --category HostPipeline
python3 gen_testcases.py --top-dir ../test/E2E/Host/AliasLike --category HostPipeline
# Host Bytecode
python3 gen_testcases.py --top-dir ../test/E2E/Host/Case0_Bytecode --category HostPipelineBytecode

Expand Down
Loading

0 comments on commit 5d9c578

Please sign in to comment.