Merge branch 'main' into tf-frontend
heromapwrd authored Jun 11, 2024
2 parents 2eacc91 + 5ee8245 commit 236b2c4
Showing 151 changed files with 855 additions and 912 deletions.
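Nearly every hunk below applies the same mechanical migration: MLIR has deprecated the member-style casts on Type and Attribute (x.cast<T>(), x.dyn_cast<T>(), x.isa<T>()) in favor of the free-function forms (cast<T>(x), dyn_cast<T>(x), isa<T>(x)), which also drop the need for the template keyword in dependent contexts. A minimal sketch of the pattern, not taken from any file in this commit (the helper name is illustrative):

// Sketch only: the cast-style migration applied throughout this commit.
// The helper name and the RankedTensorType example are illustrative.
#include "mlir/IR/BuiltinTypes.h" // mlir::RankedTensorType
#include "mlir/IR/Value.h"        // mlir::Value

// Before (deprecated member casts):
//   auto tensorTy = value.getType().dyn_cast<mlir::RankedTensorType>();
// After (free functions, as in the hunks below):
static int64_t getRankOrZero(mlir::Value value) {
  if (auto tensorTy = mlir::dyn_cast<mlir::RankedTensorType>(value.getType()))
    return tensorTy.getRank();
  return 0; // unranked or non-tensor type
}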
4 changes: 2 additions & 2 deletions compiler/dialects/lib/Dialect/Ace/IR/AceDialect.cpp
@@ -99,14 +99,14 @@ OpFoldResult mlir::ace::ConstOp::fold(FoldAdaptor) { return getValue(); }
//===----------------------------------------------------------------------===//

LogicalResult mlir::ace::ReshapeOp::verify() {
- auto operandTy = getOperand().getType().dyn_cast<RankedTensorType>();
+ auto operandTy = dyn_cast<RankedTensorType>(getOperand().getType());
// If the operand type is dynamically shaped there is nothing to verify.
if (!operandTy || !operandTy.hasStaticShape())
return success();

// If the operand type is statically shaped (not required) the number of
// elements must match that of the result type.
- auto resultTy = getResult().getType().cast<RankedTensorType>();
+ auto resultTy = cast<RankedTensorType>(getResult().getType());
assert(resultTy && resultTy.hasStaticShape() &&
"result type must be statically shaped");
int64_t numResultElements = resultTy.getNumElements();
2 changes: 1 addition & 1 deletion compiler/dialects/lib/Dialect/Ccl/IR/CclOps.cpp
@@ -38,7 +38,7 @@ verifyReplicaGroups(std::optional<Location> location,
"dynamic_replica_groups and replica_groups can't exist simultaneously");

if (dynamicReplicaGroups != nullptr) {
- ShapedType type = dynamicReplicaGroups.getType().cast<ShapedType>();
+ ShapedType type = cast<ShapedType>(dynamicReplicaGroups.getType());
if (!type.getElementType().isa<IndexType, IntegerType>())
return emitOptionalError(
location,
@@ -357,7 +357,7 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
if (auto shapedType =
- opOperand->get().getType().template dyn_cast<ShapedType>())
+ dyn_cast<ShapedType>(opOperand->get().getType()))
return shapedType.getRank();
return 0;
}]
@@ -373,7 +373,7 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
if (auto shapedType =
- opOperand->get().getType().template dyn_cast<ShapedType>())
+ dyn_cast<ShapedType>(opOperand->get().getType()))
return shapedType.getShape();
return {};
}]
2 changes: 1 addition & 1 deletion compiler/include/byteir/Dialect/Linalg/IR/LinalgExtOps.td
@@ -743,7 +743,7 @@ def LinalgExt_BatchMatmulOp : LinalgExtStructuredBase_Op<"batch_matmul",

// Additional functions
int64_t getFullRank() {
- return getInit().getType().cast<ShapedType>().getRank() + 1;
+ return cast<ShapedType>(getInit().getType()).getRank() + 1;
}

}];
4 changes: 2 additions & 2 deletions compiler/include/byteir/Utils/Utils.h
@@ -89,9 +89,9 @@ bool isSplatValue(DenseFPElementsAttr attr, double value);
inline bool isSplatElementsAttribute(DenseIntOrFPElementsAttr attr,
int64_t intValue, double doubleValue) {
if (attr.isa<DenseIntElementsAttr>()) {
- return isSplatValue(attr.cast<DenseIntElementsAttr>(), intValue);
+ return isSplatValue(cast<DenseIntElementsAttr>(attr), intValue);
} else if (attr.isa<DenseFPElementsAttr>()) {
- return isSplatValue(attr.cast<DenseFPElementsAttr>(), doubleValue);
+ return isSplatValue(cast<DenseFPElementsAttr>(attr), doubleValue);
}
assert(false && "attr must be DenseIntElementsAttr or DenseFPElementsAttr");
}
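As an aside, not part of this commit: with the free-function API the isa-then-cast pairing above can also be collapsed into a single dyn_cast per branch. A sketch under that assumption, reusing the isSplatValue overloads declared earlier in this header:

// Sketch only; an equivalent formulation of isSplatElementsAttribute that
// uses dyn_cast instead of paired isa/cast. Not part of this commit.
#include <cassert>
#include <cstdint>
#include "mlir/IR/BuiltinAttributes.h"

inline bool isSplatElementsAttributeAlt(mlir::DenseIntOrFPElementsAttr attr,
                                        int64_t intValue, double doubleValue) {
  if (auto intAttr = mlir::dyn_cast<mlir::DenseIntElementsAttr>(attr))
    return isSplatValue(intAttr, intValue); // overload assumed declared above
  if (auto fpAttr = mlir::dyn_cast<mlir::DenseFPElementsAttr>(attr))
    return isSplatValue(fpAttr, doubleValue);
  assert(false && "attr must be DenseIntElementsAttr or DenseFPElementsAttr");
  return false;
}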
4 changes: 2 additions & 2 deletions compiler/lib/Analysis/Liveness.cpp
@@ -185,7 +185,7 @@ Liveness::OperationListT Liveness::resolveLiveness(Value value) const {
if (Operation *defOp = value.getDefiningOp())
currentBlock = defOp->getBlock();
else
- currentBlock = value.cast<BlockArgument>().getOwner();
+ currentBlock = cast<BlockArgument>(value).getOwner();
toProcess.push_back(currentBlock);
visited.insert(currentBlock);

@@ -314,7 +314,7 @@ void Liveness::print(raw_ostream &os) const {
if (value.getDefiningOp())
os << "val_" << valueIds[value];
else {
- auto blockArg = value.cast<BlockArgument>();
+ auto blockArg = cast<BlockArgument>(value);
os << "arg" << blockArg.getArgNumber() << "@"
<< blockIds[blockArg.getOwner()];
}
24 changes: 11 additions & 13 deletions compiler/lib/Analysis/ShapeAnalysis.cpp
@@ -33,7 +33,7 @@ namespace shape_analysis {

ValueKnowledge ValueKnowledge::getKnowledgeFromType(Type type) {
ValueKnowledge result = getPessimisticValueState();
- if (auto shapedType = type.dyn_cast_or_null<ShapedType>()) {
+ if (auto shapedType = dyn_cast_or_null<ShapedType>(type)) {
if (shapedType.hasRank()) {
result.hasRank = true;
result.sizes.reserve(shapedType.getRank());
@@ -273,9 +273,9 @@ LogicalResult ShapeAnalysis::inferResultShapesWithKnowledges(
}
if (knowledge) {
for (auto &&resultType : op->getResultTypes()) {
- if (auto shapedType = resultType.dyn_cast_or_null<ShapedType>()) {
+ if (auto shapedType = dyn_cast_or_null<ShapedType>(resultType)) {
knowledge.dtype = shapedType.getElementType();
- results.push_back(knowledge.getType().cast<ShapedType>());
+ results.push_back(cast<ShapedType>(knowledge.getType()));
} else {
results.push_back(ShapedTypeComponents{});
}
@@ -318,7 +318,7 @@ LogicalResult ShapeAnalysis::inferResultShapesWithKnowledges(
.succeeded()) {
results.assign(llvm::to_vector(llvm::map_range(
inferredType, [](mlir::Type t) -> ShapedTypeComponents {
- if (auto st = t.dyn_cast_or_null<ShapedType>())
+ if (auto st = dyn_cast_or_null<ShapedType>(t))
return st;
return {};
})));
@@ -401,7 +401,7 @@ void ShapeAnalysis::visitOperation(Operation *op,

// Compute the knowledge based on the inferred type.
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
- inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
+ inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
@@ -464,12 +464,12 @@ void ShapeValueAnalysis::visitOperation(
return;
}
auto inputType =
- shapeLattice->getValue().getType().dyn_cast<RankedTensorType>();
+ dyn_cast<RankedTensorType>(shapeLattice->getValue().getType());
if (!inputType || !inputType.hasStaticShape()) {
return setAllToEntryStates(results);
}
auto shape = inputType.getShape();
- auto outType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
+ auto outType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
auto resultAttr = DenseIntElementsAttr::get(outType, shape);
auto lattice = results[0];
propagateIfChanged(lattice, lattice->join(ConstantValue(
@@ -487,13 +487,11 @@ void ShapeValueAnalysis::visitOperation(
return;
}
Attribute constAttr = index->getValue().getConstantValue();
- if (auto denseInt =
-     constAttr.dyn_cast_or_null<DenseIntElementsAttr>()) {
+ if (auto denseInt = dyn_cast_or_null<DenseIntElementsAttr>(constAttr)) {

- auto newType = denseInt.getType().clone(cast<arith::IndexCastOp>(op)
-     .getType()
-     .cast<RankedTensorType>()
-     .getElementType());
+ auto newType = denseInt.getType().clone(
+     cast<RankedTensorType>(cast<arith::IndexCastOp>(op).getType())
+         .getElementType());

SmallVector<APInt> newDenseInt;
uint32_t width;
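One detail the migration above preserves: dyn_cast_or_null tolerates a null handle, matching the old member form, whereas plain dyn_cast asserts on a null input. A minimal sketch (the helper name is hypothetical):

// Sketch only: contrasts dyn_cast_or_null with dyn_cast on a possibly-null
// mlir::Type. The helper name is hypothetical.
#include "mlir/IR/BuiltinTypes.h" // mlir::ShapedType, mlir::Type

static bool hasStaticShapeOrFalse(mlir::Type maybeNullType) {
  // dyn_cast_or_null yields a null ShapedType when the input is null or not
  // a shaped type; dyn_cast would assert if maybeNullType were null.
  if (auto shaped = mlir::dyn_cast_or_null<mlir::ShapedType>(maybeNullType))
    return shaped.hasStaticShape();
  return false;
}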
4 changes: 2 additions & 2 deletions compiler/lib/Conversion/GPUToNVVM/GPUToNVVM.cpp
@@ -91,8 +91,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {

mlir::Type resultType = castedOperands.front().getType();
mlir::Type funcType = getFunctionType(resultType, castedOperands);
- StringRef funcName = getFunctionName(
-     funcType.cast<LLVM::LLVMFunctionType>().getReturnType());
+ StringRef funcName =
+     getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
if (funcName.empty())
return failure();

45 changes: 21 additions & 24 deletions compiler/lib/Conversion/HloToByreTensor/HloToByreCustom.cpp
@@ -78,15 +78,15 @@ ByreCustomConfig mlir::getCudaByreCustomConfig() {
ShapedType vShapeTy;
ShapedType oShapeTy;
if (callee == getFlashAttnFwdName()) {
- qShapeTy = op.getOperand(0).getType().dyn_cast<ShapedType>();
- kShapeTy = op.getOperand(1).getType().dyn_cast<ShapedType>();
- vShapeTy = op.getOperand(2).getType().dyn_cast<ShapedType>();
- oShapeTy = op.getResult(0).getType().dyn_cast<ShapedType>();
+ qShapeTy = dyn_cast<ShapedType>(op.getOperand(0).getType());
+ kShapeTy = dyn_cast<ShapedType>(op.getOperand(1).getType());
+ vShapeTy = dyn_cast<ShapedType>(op.getOperand(2).getType());
+ oShapeTy = dyn_cast<ShapedType>(op.getResult(0).getType());
} else {
- qShapeTy = op.getOperand(1).getType().dyn_cast<ShapedType>();
- kShapeTy = op.getOperand(2).getType().dyn_cast<ShapedType>();
- vShapeTy = op.getOperand(3).getType().dyn_cast<ShapedType>();
- oShapeTy = op.getOperand(4).getType().dyn_cast<ShapedType>();
+ qShapeTy = dyn_cast<ShapedType>(op.getOperand(1).getType());
+ kShapeTy = dyn_cast<ShapedType>(op.getOperand(2).getType());
+ vShapeTy = dyn_cast<ShapedType>(op.getOperand(3).getType());
+ oShapeTy = dyn_cast<ShapedType>(op.getOperand(4).getType());
}
if (!qShapeTy || !qShapeTy.hasStaticShape() || !kShapeTy ||
!kShapeTy.hasStaticShape() || !vShapeTy ||
@@ -128,16 +128,14 @@ ByreCustomConfig mlir::getCudaByreCustomConfig() {
uint32_t oHeadStride = oShape[3];

DictionaryAttr byteirAttrs =
- op->getAttr(getCustomCallAttrName()).cast<DictionaryAttr>();
+ cast<DictionaryAttr>(op->getAttr(getCustomCallAttrName()));
if (!byteirAttrs)
assert(false && "byteir attribute not found!");
- bool causal = byteirAttrs.get("causal").cast<BoolAttr>().getValue();
- float softmaxScale = byteirAttrs.get("softmax_scale")
-     .cast<FloatAttr>()
+ bool causal = cast<BoolAttr>(byteirAttrs.get("causal")).getValue();
+ float softmaxScale = cast<FloatAttr>(byteirAttrs.get("softmax_scale"))
.getValue()
.convertToDouble();
- float dropoutP = byteirAttrs.get("dropout_p")
-     .cast<FloatAttr>()
+ float dropoutP = cast<FloatAttr>(byteirAttrs.get("dropout_p"))
.getValue()
.convertToDouble();
int windowSizeLeft = -1;
@@ -178,14 +176,14 @@ ByreCustomConfig mlir::getCudaByreCustomConfig() {
return ArrayAttr::get(rewriter.getContext(), extraArgs);
} else if (callee == getFlashAttnKVCacheName()) {
OpBuilder rewriter(op);
- ShapedType qShapeTy = op.getOperand(0).getType().dyn_cast<ShapedType>();
+ ShapedType qShapeTy = dyn_cast<ShapedType>(op.getOperand(0).getType());
ShapedType kcacheShapeTy =
- op.getOperand(1).getType().dyn_cast<ShapedType>();
+ dyn_cast<ShapedType>(op.getOperand(1).getType());
ShapedType vcacheShapeTy =
- op.getOperand(2).getType().dyn_cast<ShapedType>();
- ShapedType kShapeTy = op.getOperand(3).getType().dyn_cast<ShapedType>();
- ShapedType vShapeTy = op.getOperand(4).getType().dyn_cast<ShapedType>();
- ShapedType oShapeTy = op.getResult(0).getType().dyn_cast<ShapedType>();
+ dyn_cast<ShapedType>(op.getOperand(2).getType());
+ ShapedType kShapeTy = dyn_cast<ShapedType>(op.getOperand(3).getType());
+ ShapedType vShapeTy = dyn_cast<ShapedType>(op.getOperand(4).getType());
+ ShapedType oShapeTy = dyn_cast<ShapedType>(op.getResult(0).getType());
if (!qShapeTy || !qShapeTy.hasStaticShape() || !kShapeTy ||
!kShapeTy.hasStaticShape() || !vShapeTy ||
!vShapeTy.hasStaticShape() || !kcacheShapeTy ||
@@ -239,12 +237,11 @@ ByreCustomConfig mlir::getCudaByreCustomConfig() {
uint32_t oHeadStride = oShape[3];

DictionaryAttr byteirAttrs =
- op->getAttr(getCustomCallAttrName()).cast<DictionaryAttr>();
+ cast<DictionaryAttr>(op->getAttr(getCustomCallAttrName()));
if (!byteirAttrs)
assert(false && "byteir attribute not found!");
- bool causal = byteirAttrs.get("causal").cast<BoolAttr>().getValue();
- float softmaxScale = byteirAttrs.get("softmax_scale")
-     .cast<FloatAttr>()
+ bool causal = cast<BoolAttr>(byteirAttrs.get("causal")).getValue();
+ float softmaxScale = cast<FloatAttr>(byteirAttrs.get("softmax_scale"))
.getValue()
.convertToDouble();
int windowSizeLeft = -1;
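A side note on the attribute reads above (behavior unchanged by this commit): cast<...> on the result of DictionaryAttr::get asserts, in debug builds, if the entry is missing or has a different type, so this code relies on the attributes always being present. For optional attributes, dyn_cast_or_null gives a null-tolerant read; a sketch with an illustrative key and default:

// Sketch only; not part of this commit. Null-tolerant read of an optional
// boolean entry from a DictionaryAttr. Key and default are illustrative.
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinAttributes.h"

static bool getBoolEntryOr(mlir::DictionaryAttr attrs, llvm::StringRef key,
                           bool dflt) {
  // DictionaryAttr::get returns a null Attribute when the key is absent;
  // dyn_cast_or_null then yields a null BoolAttr instead of asserting.
  if (auto b = mlir::dyn_cast_or_null<mlir::BoolAttr>(attrs.get(key)))
    return b.getValue();
  return dflt;
}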
26 changes: 13 additions & 13 deletions compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp
@@ -147,8 +147,8 @@ template <typename OP> class ConvertReshapeOp : public OpConversionPattern<OP> {
matchAndRewrite(OP op, typename OP::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto operand = adaptor.getOperand();
- auto operandType = operand.getType().template cast<ShapedType>();
- auto resultType = op.getType().template cast<ShapedType>();
+ auto operandType = llvm::cast<ShapedType>(operand.getType());
+ auto resultType = llvm::cast<ShapedType>(op.getType());

if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
@@ -199,7 +199,7 @@ class ConvertSliceOp : public OpConversionPattern<mhlo::SliceOp> {
matchAndRewrite(mhlo::SliceOp sliceOp,
typename mhlo::SliceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- auto argType = adaptor.getOperands()[0].getType().dyn_cast<ShapedType>();
+ auto argType = dyn_cast<ShapedType>(adaptor.getOperands()[0].getType());
if (!argType || !argType.hasRank()) {
return rewriter.notifyMatchFailure(sliceOp, "expects known-rank args");
}
@@ -238,7 +238,7 @@ class ConvertConcatenateOp : public OpConversionPattern<mhlo::ConcatenateOp> {

uint64_t axis = concatOp.getDimension();
if (llvm::any_of(adaptor.getOperands(), [&](auto &&value) {
- return value.getType().template cast<ShapedType>().isDynamicDim(axis);
+ return cast<ShapedType>(value.getType()).isDynamicDim(axis);
}))
return failure();

@@ -266,7 +266,7 @@ class ConvertConcatenateOp : public OpConversionPattern<mhlo::ConcatenateOp> {
resultType, dynDims);
int64_t upperBound = 0;
for (auto &&operand : adaptor.getOperands()) {
- auto operandType = operand.getType().cast<ShapedType>();
+ auto operandType = cast<ShapedType>(operand.getType());
static_offsets[axis] = upperBound;
static_sizes[axis] = operandType.getDimSize(axis);
value = rewriter.create<tensor::InsertSliceOp>(
@@ -289,13 +289,13 @@ class ConvertGatherOpToByrePattern
matchAndRewrite(mhlo::GatherOp op, typename mhlo::GatherOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto startIndices = op.getStartIndices();
- auto startIndicesTy = startIndices.getType().cast<ShapedType>();
+ auto startIndicesTy = cast<ShapedType>(startIndices.getType());
if (!startIndicesTy.hasRank()) {
return rewriter.notifyMatchFailure(op, "unranked start_indices");
}

auto operand = op.getOperand();
- auto operandTy = operand.getType().cast<ShapedType>();
+ auto operandTy = cast<ShapedType>(operand.getType());
if (!operandTy.hasRank()) {
return rewriter.notifyMatchFailure(op, "unranked operand");
}
@@ -320,7 +320,7 @@ class ConvertGatherOpToByrePattern
return rewriter.notifyMatchFailure(op, "start_index_map != [0]");
}

- auto resultTy = op.getResult().getType().dyn_cast<ShapedType>();
+ auto resultTy = dyn_cast<ShapedType>(op.getResult().getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op, "unranked result");
}
@@ -423,8 +423,8 @@ class ConvertDotOpToByrePattern : public OpConversionPattern<mhlo::DotOp> {
matchAndRewrite(mlir::mhlo::DotOp op, mlir::mhlo::DotOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: support matrix * vector, vector * matrix and vector * vector
- if (adaptor.getLhs().getType().cast<ShapedType>().getRank() != 2 ||
-     adaptor.getRhs().getType().cast<ShapedType>().getRank() != 2)
+ if (cast<ShapedType>(adaptor.getLhs().getType()).getRank() != 2 ||
+     cast<ShapedType>(adaptor.getRhs().getType()).getRank() != 2)
return failure();

auto failureOrComputeOnTensorOp = replaceMhloOpWithByreComputeOnTensorOp(
@@ -483,7 +483,7 @@ class ConvertDotGeneralOpToByrePattern
// convert to BatchMatmulOp
SmallVector<int64_t> batchingDimensions;
for (int64_t i = 0,
- e = op->getResult(0).getType().cast<ShapedType>().getRank();
+ e = cast<ShapedType>(op->getResult(0).getType()).getRank();
i < e - 2; i++) {
batchingDimensions.push_back(i);
}
@@ -596,7 +596,7 @@ class ConvertReduceOpToByrePattern
return rewriter.notifyMatchFailure(op, "unsupported block in reduce");
}

- auto inputShape = adaptor.getInputs()[0].getType().dyn_cast<ShapedType>();
+ auto inputShape = dyn_cast<ShapedType>(adaptor.getInputs()[0].getType());
if (!inputShape || !inputShape.hasRank()) {
return rewriter.notifyMatchFailure(op, "invalid input type");
}
@@ -648,7 +648,7 @@ class ConvertReduceWindowOpToByrePattern
return rewriter.notifyMatchFailure(
op, "batched reductions is not supported yet");
}
- auto inputShape = adaptor.getInputs()[0].getType().dyn_cast<ShapedType>();
+ auto inputShape = dyn_cast<ShapedType>(adaptor.getInputs()[0].getType());
if (!inputShape || !inputShape.hasRank()) {
return rewriter.notifyMatchFailure(op, "invalid input type");
}
(The remaining changed files in this commit are not shown here.)
