Skip to content

Commit

Permalink
[midend/lib/Conversion/ConvVectorization] fix some code style
Browse files Browse the repository at this point in the history
  • Loading branch information
FloatingcloudKnight committed Dec 27, 2024
1 parent 5131c79 commit 40c48a0
Showing 1 changed file with 50 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern {
// Get strides.
SmallVector<int64_t, 2> strides = {1, 1};
if (op->hasAttr("strides")) {
// 获取 "strides" 属性
if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>("strides")) {
strides.clear(); // 清空默认值
for (auto value : attr.getValues<int64_t>()) {
strides.push_back(value);
}
if (auto attr =
op->getAttrOfType<mlir::DenseIntElementsAttr>("strides")) {
strides.clear();
for (auto value : attr.getValues<int64_t>()) {
strides.push_back(value);
}
}
}
bool stride1 = strides[0] != 1;
bool stride2 = strides[1] != 1;
Expand All @@ -100,16 +100,18 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern {
// // Get dilations.
SmallVector<int64_t, 2> dilations = {1, 1};
if (op->hasAttr("dilations")) {
if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>("dilations")) {
dilations.clear();
for (auto value : attr.getValues<int64_t>()) {
dilations.push_back(value);
}
if (auto attr =
op->getAttrOfType<mlir::DenseIntElementsAttr>("dilations")) {
dilations.clear();
for (auto value : attr.getValues<int64_t>()) {
dilations.push_back(value);
}
}
}
bool dilated1 = dilations[0] != 1;
bool dilated2 = dilations[1] != 1;
Value dilHeight = rewriter.create<arith::ConstantIndexOp>(loc, dilations[0]);
Value dilHeight =
rewriter.create<arith::ConstantIndexOp>(loc, dilations[0]);
Value dilWidth = rewriter.create<arith::ConstantIndexOp>(loc, dilations[1]);

// Get ElementType of input.
Expand All @@ -121,7 +123,7 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern {
const Value c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
const Value c2 = rewriter.create<arith::ConstantIndexOp>(loc, 2);
const Value c3 = rewriter.create<arith::ConstantIndexOp>(loc, 3);
const Value vl_step = rewriter.create<arith::ConstantIndexOp>(loc, strip);
const Value vlStep = rewriter.create<arith::ConstantIndexOp>(loc, strip);
const Value zero =
buddy::insertZeroConstantOp(ctx, rewriter, loc, elementTy);

Expand All @@ -139,12 +141,11 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern {
Value channels = rewriter.create<memref::DimOp>(loc, output, c3);

// Calculate the upper bound for vectorized processing
// - Subtract `vl_step` is to avoid overflow at the vectorization tail.
// - Subtract `vlStep` is to avoid overflow at the vectorization tail.
// - Add 1 to ensure the final loop runs when the workload length
// is divisible by the vector size.
Value upperBound_tmp =
rewriter.create<arith::SubIOp>(loc, channels, vl_step);
Value upperBound = rewriter.create<arith::AddIOp>(loc, upperBound_tmp, c1);
Value upperBoundTmp = rewriter.create<arith::SubIOp>(loc, channels, vlStep);
Value upperBound = rewriter.create<arith::AddIOp>(loc, upperBoundTmp, c1);

SmallVector<Value, 8> lowerBounds(3, c0);
SmallVector<Value, 8> uperBounds{batch, height, width};
Expand All @@ -153,17 +154,17 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern {
rewriter, loc, lowerBounds, uperBounds, steps,
[&](OpBuilder &builder, Location loc, ValueRange ivs) {
// Create strides variables.
Value tmp_ivs1 = ivs[1];
if(stride1){
tmp_ivs1 = builder.create<arith::MulIOp>(loc, ivs[1], strHeight);
Value tmpIvs1 = ivs[1];
if (stride1) {
tmpIvs1 = builder.create<arith::MulIOp>(loc, ivs[1], strHeight);
}
Value tmp_ivs2 = ivs[2];
if(stride2){
tmp_ivs2 = builder.create<arith::MulIOp>(loc, ivs[2], strWidth);
Value tmpIvs2 = ivs[2];
if (stride2) {
tmpIvs2 = builder.create<arith::MulIOp>(loc, ivs[2], strWidth);
}
// Create strip mining loop.
auto iter_idx = builder.create<scf::ForOp>(
loc, c0, upperBound, /*Step=*/vl_step, ValueRange{c0},
auto iterIdx = builder.create<scf::ForOp>(
loc, c0, upperBound, /*Step=*/vlStep, ValueRange{c0},
[&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
ValueRange itrArgs) {
Value outputVector = nestedBuilder.create<vector::LoadOp>(
Expand All @@ -177,25 +178,27 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern {
[&](OpBuilder &builder, Location loc, Value iv0,
ValueRange itrArgs0) {
// Create dilated[0] variables.
Value tmp_ivs3 = iv0;
if(dilated1){
tmp_ivs3 = builder.create<arith::MulIOp>(loc, iv0, dilHeight);
Value tmpIvs3 = iv0;
if (dilated1) {
tmpIvs3 =
builder.create<arith::MulIOp>(loc, iv0, dilHeight);
}
Value inputHeight =
builder.create<arith::AddIOp>(loc, tmp_ivs1, tmp_ivs3);
builder.create<arith::AddIOp>(loc, tmpIvs1, tmpIvs3);
auto tmp1 = builder.create<affine::AffineForOp>(
loc, ValueRange{c0}, builder.getDimIdentityMap(),
ValueRange{kernelWidth}, builder.getDimIdentityMap(),
/*Step=*/1, ValueRange{itrArgs0[0]},
[&](OpBuilder &builder, Location loc, Value iv1,
ValueRange itrArgs1) {
// Create dilated[1] variables.
Value tmp_ivs4 = iv1;
if(dilated2){
tmp_ivs4 = builder.create<arith::MulIOp>(loc, iv1, dilWidth);
Value tmpIvs4 = iv1;
if (dilated2) {
tmpIvs4 = builder.create<arith::MulIOp>(loc, iv1,
dilWidth);
}
Value inputWidth =
builder.create<arith::AddIOp>(loc, tmp_ivs2, tmp_ivs4);
Value inputWidth = builder.create<arith::AddIOp>(
loc, tmpIvs2, tmpIvs4);
Value inputVector = builder.create<vector::LoadOp>(
loc, vectorTy, input,
ValueRange{ivs[0], inputHeight, inputWidth,
Expand All @@ -219,12 +222,13 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern {
builder.create<vector::StoreOp>(
loc, tmp0.getResult(0), output,
ValueRange{ivs[0], ivs[1], ivs[2], iv});
Value idx = builder.create<arith::AddIOp>(loc, itrArgs[0], vl_step);
Value idx =
builder.create<arith::AddIOp>(loc, itrArgs[0], vlStep);
builder.create<scf::YieldOp>(loc, idx);
});
// Compute the tail size and Process the remaining elements
// using masked vector operations.
Value idx = iter_idx.getResult(0);
Value idx = iterIdx.getResult(0);
Value tailSize = builder.create<arith::SubIOp>(loc, channels, idx);
Value tailCond = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, tailSize, c0);
Expand All @@ -245,12 +249,13 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern {
[&](OpBuilder &builder, Location loc, Value iv0,
ValueRange itrArgs0) {
// Create dilated[0] variables.
Value tmp_ivs3 = iv0;
if(dilated1){
tmp_ivs3 = builder.create<arith::MulIOp>(loc, iv0, dilHeight);
Value tmpIvs3 = iv0;
if (dilated1) {
tmpIvs3 =
builder.create<arith::MulIOp>(loc, iv0, dilHeight);
}
Value inputHeight =
builder.create<arith::AddIOp>(loc, tmp_ivs1, tmp_ivs3);
builder.create<arith::AddIOp>(loc, tmpIvs1, tmpIvs3);
auto tmp1 = builder.create<affine::AffineForOp>(
loc, ValueRange{c0}, builder.getDimIdentityMap(),
ValueRange{kernelWidth}, builder.getDimIdentityMap(),
Expand All @@ -260,12 +265,13 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern {
// Calculate the index of the input and
// output.
// Create dilated[1] variables.
Value tmp_ivs4 = iv1;
if(dilated2){
tmp_ivs4 = builder.create<arith::MulIOp>(loc, iv1, dilWidth);
Value tmpIvs4 = iv1;
if (dilated2) {
tmpIvs4 =
builder.create<arith::MulIOp>(loc, iv1, dilWidth);
}
Value inputWidth =
builder.create<arith::AddIOp>(loc, iv1, tmp_ivs2);
builder.create<arith::AddIOp>(loc, iv1, tmpIvs2);
// Masked load input and output.
Value maskedInputVec = builder.create<MaskedLoadOp>(
loc, vectorTy, input,
Expand Down

0 comments on commit 40c48a0

Please sign in to comment.