Skip to content

Commit

Permalink
[aievec] Add hoisting patterns for arith.extsi
Browse files Browse the repository at this point in the history
Hoisting cast operations as close as possible to the source of data can
make later patterns more robust to typical variations in the source
code.

We might need to revisit this one if, in the future, this process
causes unintended consequences.
  • Loading branch information
jsetoain committed Aug 31, 2023
1 parent 98eceb3 commit 1bab3df
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 8 deletions.
95 changes: 94 additions & 1 deletion lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,79 @@ struct ConvertSplatTransferReadToBroadcastPattern
}
};

// Moves `arith.extsi` towards the producer of the data it extends: instead of
// sign-extending the result of a pure data-movement op, the operands of that
// op are sign-extended and the data-movement op is re-created on the wider
// type. Applied greedily, this pushes the cast up to the data source, so later
// patterns do not have to cope with casts sitting in between data manipulation
// ops and arithmetic ops.
// TODO: Generalize this pattern and instantiate it for other kinds of cast
//       ops.
struct HoistCastOpToDataSourcePattern : public RewritePattern {
  HoistCastOpToDataSourcePattern(MLIRContext *context)
      : RewritePattern(arith::ExtSIOp::getOperationName(), /*benefit=*/1,
                       context) {}

  // Matches an `arith.extsi` whose input is produced by `vector.broadcast`,
  // `vector.extract` or `vector.extract_strided_slice`, and swaps the two ops.
  // Returns failure (leaving the IR untouched) for any other producer.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto extOp = cast<arith::ExtSIOp>(op);
    Value narrowVal = extOp.getIn();
    Operation *producer = narrowVal.getDefiningOp();

    // Only ops known to commute with the cast are accepted. Values without a
    // defining op and data source ops (transfer reads, loads, calls, ...) are
    // not in this list, so hoisting naturally stops when it reaches them.
    if (!producer || !isa<vector::BroadcastOp, vector::ExtractOp,
                          vector::ExtractStridedSliceOp>(producer))
      return failure();

    Type narrowTy = narrowVal.getType();
    Type wideTy = extOp.getOut().getType();
    auto narrowVecTy = dyn_cast<VectorType>(narrowTy);

    // Computes the type a producer operand has to be extended to. A null type
    // means the operand is unrelated to the cast and must be forwarded as is.
    auto widenedTypeFor = [&](Value operand) -> Type {
      Type operandTy = operand.getType();

      // Same type as the cast input: reuse the cast output type.
      if (operandTy == narrowTy)
        return wideTy;

      // Scalar operand feeding a vector result: scalar -> scalar extension.
      if (narrowVecTy && narrowVecTy.getElementType() == operandTy)
        return cast<VectorType>(wideTy).getElementType();

      // The remaining cases only apply to vector operands.
      auto operandVecTy = dyn_cast<VectorType>(operandTy);
      if (!operandVecTy)
        return Type();

      // Vector operand feeding a scalar result: vector -> vector extension.
      if (operandVecTy.getElementType() == narrowTy)
        return VectorType::get(operandVecTy.getShape(), wideTy);

      // Vector operand feeding a vector of another shape: keep the operand
      // shape and switch to the wide element type.
      if (narrowVecTy &&
          narrowVecTy.getElementType() == operandVecTy.getElementType())
        return VectorType::get(operandVecTy.getShape(),
                               cast<VectorType>(wideTy).getElementType());

      return Type();
    };

    SmallVector<Value, 4> newOperands;
    for (Value operand : producer->getOperands()) {
      if (Type extendedTy = widenedTypeFor(operand)) {
        auto hoistedExt = rewriter.create<arith::ExtSIOp>(extOp.getLoc(),
                                                          extendedTy, operand);
        newOperands.push_back(hoistedExt.getOut());
      } else {
        newOperands.push_back(operand);
      }
    }

    // Re-create the producer generically (same op name and attributes) on the
    // extended operands, yielding the type the original cast produced.
    Operation *widenedProducer =
        rewriter.create(extOp->getLoc(), producer->getName().getIdentifier(),
                        newOperands, {wideTy}, producer->getAttrs());
    rewriter.replaceOp(extOp, widenedProducer->getResult(0));
    return success();
  }
};

//============================================================================//
//============ AIEML canonicalization conversion patterns ===============//
//============================================================================//
Expand Down Expand Up @@ -281,7 +354,7 @@ struct CanonicalizeVectorForAIEVecPass

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, memref::MemRefDialect,
vector::VectorDialect>();
vector::VectorDialect, AffineDialect>();
}

Option<std::string> aieTarget{
Expand Down Expand Up @@ -329,6 +402,24 @@ static std::unique_ptr<::mlir::Pass> createCanonicalizeVectorForAIEVecPass(
return std::make_unique<CanonicalizeVectorForAIEVecPass>(options);
}

// Function-level pass that greedily applies HoistCastOpToDataSourcePattern to
// the body of the `func.func` it runs on.
struct HoistCastOpToDataSourcePass
    : public PassWrapper<HoistCastOpToDataSourcePass,
                         OperationPass<func::FuncOp>> {
  void runOnOperation() override {
    RewritePatternSet hoistingPatterns(&getContext());
    hoistingPatterns.add<HoistCastOpToDataSourcePattern>(
        hoistingPatterns.getContext());

    // The result of the greedy driver is explicitly discarded: the pass never
    // signals failure based on it.
    (void)applyPatternsAndFoldGreedily(getOperation(),
                                       std::move(hoistingPatterns));
  }
};

// Factory for HoistCastOpToDataSourcePass; the pass takes no options.
static std::unique_ptr<::mlir::Pass> createHoistCastOpToDataSourcePass() {
  auto hoistPass = std::make_unique<HoistCastOpToDataSourcePass>();
  return hoistPass;
}

//============================================================================//
//=============== Main Vector2Vector Pipeline Configuration ==================//
//============================================================================//
Expand All @@ -340,4 +431,6 @@ void xilinx::aievec::buildCanonicalizeVectorForAIEVec(
// TODO: Add passes to split vectors that won't fit in registers
pm.addPass(createCopyRemovalPass());
pm.addPass(createCanonicalizeVectorForAIEVecPass(options));

pm.addPass(createHoistCastOpToDataSourcePass());
}
9 changes: 4 additions & 5 deletions test/Conversion/VectorToAIEVec/unaligned-load-aieml.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ func.func @unaligned_read(%a: memref<48xi8>) -> (vector<32xi8>, vector<32xi8>) {
}

// CHECK-LABEL: func @unaligned_read
// CHECK: %[[C2i32:.*]] = arith.constant 2 : i32
// CHECK: %[[C32:.*]] = arith.constant 32 : index
// CHECK: %[[C16i32:.*]] = arith.constant 16 : i32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2i32:.*]] = arith.constant 2 : i32
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[C16i32:.*]] = arith.constant 16 : i32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T0:.*]] = aievec.upd {{.*}}[%[[C0:.*]]] {index = 0 : i8, offset = 0 : si32} : memref<48xi8>, vector<64xi8>
// CHECK: %[[T0E0:.*]] = aievec.ext %[[T0]] {index = 0 : i8} : vector<64xi8>, vector<32xi8>
// CHECK: %[[T0E1:.*]] = aievec.ext %[[T0]] {index = 1 : i8} : vector<64xi8>, vector<32xi8>
Expand All @@ -22,4 +22,3 @@ func.func @unaligned_read(%a: memref<48xi8>) -> (vector<32xi8>, vector<32xi8>) {
// CHECK: %[[T1E1:.*]] = aievec.ext %[[T1]] {index = 1 : i8} : vector<64xi8>, vector<32xi8>
// CHECK: %[[R1:.*]] = aievec.shift %[[T1E0]], %[[T1E1]], %[[C2i32]] {isAcc = false} : vector<32xi8>, vector<32xi8>, i32, vector<32xi8>
// CHECK: return %[[R0:.*]], %[[R1:.*]] : vector<32xi8>, vector<32xi8>

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// REQUIRES: valid_xchess_license
// RUN: aie-opt %s -convert-vector-to-aievec | aie-translate -aievec-to-cpp -o kernel.tmp.cc
// RUN: echo "#include <cstdint>" > kernel.cc && cat kernel.tmp.cc >> kernel.cc
// RUN: xchesscc_wrapper aie -f -g +s +w work +o work -I%S -I. -I%aietools/include -D__AIENGINE__ kernel.cc %S/helplib.cc %S/main.cc
// RUN: xchesscc_wrapper aie -f -g +s +w work +o work -I%S -I. -I%aietools/include -D__AIENGINE__ -D__AIEARCH__=10 kernel.cc %S/helplib.cc %S/main.cc
// RUN: xca_udm_dbg -qf -T -P %aietools/data/versal_prod/lib -t "%S/../../profiling.tcl ./work/a.out" | FileCheck %s

func.func private @printv16xi16(%v : vector<16xi16>)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// REQUIRES: valid_xchess_license
// RUN: aie-opt %s -convert-vector-to-aievec | aie-translate -aievec-to-cpp -o kernel.tmp.cc
// RUN: echo "#include <cstdint>" > kernel.cc && cat kernel.tmp.cc >> kernel.cc
// RUN: xchesscc_wrapper aie -f -g +s +w work +o work -I%S -I. -I%aietools/include -D__AIENGINE__ kernel.cc %S/helplib.cc %S/main.cc
// RUN: xchesscc_wrapper aie -f -g +s +w work +o work -I%S -I. -I%aietools/include -D__AIENGINE__ -D__AIEARCH__=10 kernel.cc %S/helplib.cc %S/main.cc
// RUN: xca_udm_dbg -qf -T -P %aietools/data/versal_prod/lib -t "%S/../../profiling.tcl ./work/a.out" | FileCheck %s

func.func private @printv8xi32(%v : vector<8xi32>)
Expand Down
41 changes: 41 additions & 0 deletions test/dialect/AIEVec/precanonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,44 @@ func.func @rank_zero_transfer_read(%m : memref<i16>) -> vector<16xi16> {
// CHECK: return %[[S]] : vector<16xi16>
return %v : vector<16xi16>
}

// -----

// Verifies that an `arith.extsi` applied to the result of a
// `vector.extract` + `vector.broadcast` chain is hoisted through both ops, so
// that the sign extension ends up applied directly to the function argument
// and the extract/broadcast operate on the extended (i32) values.
// CHECK-LABEL: func.func @extsi_hoisting_through_extract_n_bcast(
// CHECK-SAME: %[[VEC:.*]]: vector<16xi8>
func.func @extsi_hoisting_through_extract_n_bcast(%v : vector<16xi8>)
-> vector<32xi32> {
// CHECK: %[[EXV:.*]] = arith.extsi %[[VEC]] : vector<16xi8> to vector<16xi32>
// CHECK: %[[EXS:.*]] = vector.extract %[[EXV]][7] : vector<16xi32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXS]] : i32 to vector<32xi32>
// CHECK: return %[[BCAST]] : vector<32xi32>
%si8 = vector.extract %v[7] : vector<16xi8>
%vi8 = vector.broadcast %si8 : i8 to vector<32xi8>
%vi32 = arith.extsi %vi8 : vector<32xi8> to vector<32xi32>
return %vi32 : vector<32xi32>
}

// -----

// Verifies that an `arith.extsi` applied to the result of a
// `vector.extract_strided_slice` is hoisted above the slice, and that hoisting
// stops at the `vector.transfer_read` feeding it: the read stays on i8, its
// full 32-element result is sign-extended, and the slice (with unchanged
// offsets/sizes/strides) is taken from the extended vector.
// CHECK-LABEL: func.func @extsi_hoisting_through_extract_strided_slice(
// CHECK-SAME: %[[MEM:.*]]: memref<?xi8>
func.func @extsi_hoisting_through_extract_strided_slice(%m : memref<?xi8>)
-> vector<16xi32> {
// CHECK-DAG: %[[C0i8:.*]] = arith.constant 0 : i8
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
%c0_i8 = arith.constant 0 : i8
%c0 = arith.constant 0 : index
// CHECK: %[[VEC:.*]] = vector.transfer_read %[[MEM]][%[[C0]]], %[[C0i8]] :
// CHECK-SAME: memref<?xi8>, vector<32xi8>
// CHECK: %[[EXV:.*]] = arith.extsi %[[VEC]] : vector<32xi8> to vector<32xi32>
// CHECK: %[[SLC:.*]] = vector.extract_strided_slice %[[EXV]]
// CHECK-SAME: {offsets = [3], sizes = [16], strides = [1]} :
// CHECK-SAME: vector<32xi32> to vector<16xi32>
// CHECK: return %[[SLC]] : vector<16xi32>
%v = vector.transfer_read %m[%c0], %c0_i8 : memref<?xi8>, vector<32xi8>
%slice = vector.extract_strided_slice %v
{offsets = [3], sizes = [16], strides = [1]} :
vector<32xi8> to vector<16xi8>
%vi32 = arith.extsi %slice : vector<16xi8> to vector<16xi32>
return %vi32 : vector<16xi32>
}

0 comments on commit 1bab3df

Please sign in to comment.