diff --git a/midend/lib/Conversion/ConvVectorization/PoolingNhwcMaxVectorization.cpp b/midend/lib/Conversion/ConvVectorization/PoolingNhwcMaxVectorization.cpp index 1f5d698ae7..280d11b226 100644 --- a/midend/lib/Conversion/ConvVectorization/PoolingNhwcMaxVectorization.cpp +++ b/midend/lib/Conversion/ConvVectorization/PoolingNhwcMaxVectorization.cpp @@ -84,13 +84,13 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern { // Get strides. SmallVector strides = {1, 1}; if (op->hasAttr("strides")) { - // 获取 "strides" 属性 - if (auto attr = op->getAttrOfType("strides")) { - strides.clear(); // 清空默认值 - for (auto value : attr.getValues()) { - strides.push_back(value); - } + if (auto attr = + op->getAttrOfType("strides")) { + strides.clear(); + for (auto value : attr.getValues()) { + strides.push_back(value); } + } } bool stride1 = strides[0] != 1; bool stride2 = strides[1] != 1; @@ -100,16 +100,18 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern { // // Get dilations. SmallVector dilations = {1, 1}; if (op->hasAttr("dilations")) { - if (auto attr = op->getAttrOfType("dilations")) { - dilations.clear(); - for (auto value : attr.getValues()) { - dilations.push_back(value); - } + if (auto attr = + op->getAttrOfType("dilations")) { + dilations.clear(); + for (auto value : attr.getValues()) { + dilations.push_back(value); } + } } bool dilated1 = dilations[0] != 1; bool dilated2 = dilations[1] != 1; - Value dilHeight = rewriter.create(loc, dilations[0]); + Value dilHeight = + rewriter.create(loc, dilations[0]); Value dilWidth = rewriter.create(loc, dilations[1]); // Get ElementType of input. @@ -121,7 +123,7 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern { const Value c1 = rewriter.create(loc, 1); const Value c2 = rewriter.create(loc, 2); const Value c3 = rewriter.create(loc, 3); - const Value vl_step = rewriter.create(loc, strip); + const Value vlStep = rewriter.create(loc, strip); const Value zero = buddy::insertZeroConstantOp(ctx, rewriter, loc, elementTy); @@ -139,12 +141,11 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern { Value channels = rewriter.create(loc, output, c3); // Calculate the upper bound for vectorized processing - // - Subtract `vl_step` is to avoid overflow at the vectorization tail. + // - Subtract `vlStep` is to avoid overflow at the vectorization tail. // - Add 1 to ensure the final loop runs when the workload length // is divisible by the vector size. - Value upperBound_tmp = - rewriter.create(loc, channels, vl_step); - Value upperBound = rewriter.create(loc, upperBound_tmp, c1); + Value upperBoundTmp = rewriter.create(loc, channels, vlStep); + Value upperBound = rewriter.create(loc, upperBoundTmp, c1); SmallVector lowerBounds(3, c0); SmallVector uperBounds{batch, height, width}; @@ -153,17 +154,17 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern { rewriter, loc, lowerBounds, uperBounds, steps, [&](OpBuilder &builder, Location loc, ValueRange ivs) { // Create strides variables. - Value tmp_ivs1 = ivs[1]; - if(stride1){ - tmp_ivs1 = builder.create(loc, ivs[1], strHeight); + Value tmpIvs1 = ivs[1]; + if (stride1) { + tmpIvs1 = builder.create(loc, ivs[1], strHeight); } - Value tmp_ivs2 = ivs[2]; - if(stride2){ - tmp_ivs2 = builder.create(loc, ivs[2], strWidth); + Value tmpIvs2 = ivs[2]; + if (stride2) { + tmpIvs2 = builder.create(loc, ivs[2], strWidth); } // Create strip mining loop. - auto iter_idx = builder.create( - loc, c0, upperBound, /*Step=*/vl_step, ValueRange{c0}, + auto iterIdx = builder.create( + loc, c0, upperBound, /*Step=*/vlStep, ValueRange{c0}, [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, ValueRange itrArgs) { Value outputVector = nestedBuilder.create( @@ -177,12 +178,13 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern { [&](OpBuilder &builder, Location loc, Value iv0, ValueRange itrArgs0) { // Create dilated[0] variables. - Value tmp_ivs3 = iv0; - if(dilated1){ - tmp_ivs3 = builder.create(loc, iv0, dilHeight); + Value tmpIvs3 = iv0; + if (dilated1) { + tmpIvs3 = + builder.create(loc, iv0, dilHeight); } Value inputHeight = - builder.create(loc, tmp_ivs1, tmp_ivs3); + builder.create(loc, tmpIvs1, tmpIvs3); auto tmp1 = builder.create( loc, ValueRange{c0}, builder.getDimIdentityMap(), ValueRange{kernelWidth}, builder.getDimIdentityMap(), @@ -190,12 +192,13 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern { [&](OpBuilder &builder, Location loc, Value iv1, ValueRange itrArgs1) { // Create dilated[1] variables. - Value tmp_ivs4 = iv1; - if(dilated2){ - tmp_ivs4 = builder.create(loc, iv1, dilWidth); + Value tmpIvs4 = iv1; + if (dilated2) { + tmpIvs4 = builder.create(loc, iv1, + dilWidth); } - Value inputWidth = - builder.create(loc, tmp_ivs2, tmp_ivs4); + Value inputWidth = builder.create( + loc, tmpIvs2, tmpIvs4); Value inputVector = builder.create( loc, vectorTy, input, ValueRange{ivs[0], inputHeight, inputWidth, @@ -219,12 +222,13 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern { builder.create( loc, tmp0.getResult(0), output, ValueRange{ivs[0], ivs[1], ivs[2], iv}); - Value idx = builder.create(loc, itrArgs[0], vl_step); + Value idx = + builder.create(loc, itrArgs[0], vlStep); builder.create(loc, idx); }); // Compute the tail size and Process the remaining elements // using masked vector operations. - Value idx = iter_idx.getResult(0); + Value idx = iterIdx.getResult(0); Value tailSize = builder.create(loc, channels, idx); Value tailCond = rewriter.create( loc, arith::CmpIPredicate::sgt, tailSize, c0); @@ -245,12 +249,13 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern { [&](OpBuilder &builder, Location loc, Value iv0, ValueRange itrArgs0) { // Create dilated[0] variables. - Value tmp_ivs3 = iv0; - if(dilated1){ - tmp_ivs3 = builder.create(loc, iv0, dilHeight); + Value tmpIvs3 = iv0; + if (dilated1) { + tmpIvs3 = + builder.create(loc, iv0, dilHeight); } Value inputHeight = - builder.create(loc, tmp_ivs1, tmp_ivs3); + builder.create(loc, tmpIvs1, tmpIvs3); auto tmp1 = builder.create( loc, ValueRange{c0}, builder.getDimIdentityMap(), ValueRange{kernelWidth}, builder.getDimIdentityMap(), @@ -260,12 +265,13 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern { // Calculate the index of the input and // output. // Create dilated[1] variables. - Value tmp_ivs4 = iv1; - if(dilated2){ - tmp_ivs4 = builder.create(loc, iv1, dilWidth); + Value tmpIvs4 = iv1; + if (dilated2) { + tmpIvs4 = + builder.create(loc, iv1, dilWidth); } Value inputWidth = - builder.create(loc, iv1, tmp_ivs2); + builder.create(loc, iv1, tmpIvs2); // Masked load input and output. Value maskedInputVec = builder.create( loc, vectorTy, input,