diff --git a/midend/lib/Conversion/MatMulOptimization/MatMulParallelVectorization.cpp b/midend/lib/Conversion/MatMulOptimization/MatMulParallelVectorization.cpp index 23d0ef4e7b..f84aeb2223 100644 --- a/midend/lib/Conversion/MatMulOptimization/MatMulParallelVectorization.cpp +++ b/midend/lib/Conversion/MatMulOptimization/MatMulParallelVectorization.cpp @@ -77,6 +77,8 @@ class MatMulParallelVectorizationPattern : public ConversionPattern { // Define constants. const Value zeroIndex = rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value oneIndex = + rewriter.create(loc, rewriter.getIndexAttr(1)); const AffineExpr d0 = rewriter.getAffineDimExpr(0); const AffineExpr d1 = rewriter.getAffineDimExpr(1); const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); @@ -88,31 +90,17 @@ class MatMulParallelVectorizationPattern : public ConversionPattern { loc, VectorType::get({affineVectorSize}, elementType), zeroElementType); // Get dimensions of input tensors. - Value aRow = rewriter.create(loc, A, 0); - Value bCol = rewriter.create(loc, B, 1); - Value bRow = rewriter.create(loc, B, 0); - - // Calculate the length of the tail, which might not fit in a vector. - Value tailLength = rewriter.create( - loc, AffineMap::get(1, 0, d0 % affineVectorSize), ValueRange{bCol}); - - // Generate a mask vector based on the tail length. - Value maskVector = rewriter.create( - loc, VectorType::get({affineVectorSize}, rewriter.getI1Type()), - ValueRange{tailLength}); + Value aRow = rewriter.create(loc, A, zeroIndex); + Value bCol = rewriter.create(loc, B, oneIndex); + Value bRow = rewriter.create(loc, B, zeroIndex); SmallVector reducedValues = llvm::to_vector<4>( llvm::map_range(ArrayRef{}, [](const LoopReduction &red) { return red.value; })); - // Apply the column of matrix B. - Value appliedColOfB = rewriter.create( - loc, AffineMap::get(1, 0, d0.ceilDiv(affineVectorSize)), - ValueRange{bCol}); - // Create the primary parallel loop for matrix multiplication. AffineParallelOp parallelLoop = rewriter.create( - loc, ValueRange(reducedValues).getTypes(), ValueRange{appliedColOfB}, + loc, ValueRange(reducedValues).getTypes(), ValueRange{bCol}, ArrayRef{ rewriter.getNamedAttr("lowerBoundsGroups", rewriter.getI32TensorAttr({1})), @@ -126,7 +114,7 @@ class MatMulParallelVectorizationPattern : public ConversionPattern { AffineMapAttr::get(AffineMap::get( 1, 0, {d0}, rewriter.getContext()))), rewriter.getNamedAttr("reductions", rewriter.getArrayAttr({})), - rewriter.getNamedAttr("steps", rewriter.getI64ArrayAttr({1}))}); + rewriter.getNamedAttr("steps", rewriter.getI64ArrayAttr({affineVectorSize}))}); // Create the loop body for the parallel loop. Block *loopBody = new Block(); @@ -147,35 +135,37 @@ class MatMulParallelVectorizationPattern : public ConversionPattern { affine::AffineIfOp branchingOp = rewriter.create( loc, IntegerSet::get( - 1, 1, {d0 * -affineVectorSize + s0 - affineVectorSize}, {false}), + 1, 1, {s0 - d0 - affineVectorSize}, {false}), ValueRange{loopVarColOfB, bCol}, true); // Branch handling full vector operations. OpBuilder trueBranchBuilder = branchingOp.getThenBodyBuilder(); - affine::buildAffineLoopNest( - trueBranchBuilder, loc, {zeroIndex}, {bRow}, 1, + affine::buildAffineLoopNest( + trueBranchBuilder, loc, {zeroIndex}, {aRow}, 1, [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarRowOfB = ivRange.front(); - Value bVec = builder.create( - loc, VectorType::get({affineVectorSize}, elementType), B, - AffineMap::get(2, 0, {d0, d1 * affineVectorSize}, + Value loopVarRowOfA = ivRange.front(); + Value cVec = builder.create( + loc, VectorType::get({affineVectorSize}, elementType), C, + AffineMap::get(2, 0, {d0, d1}, rewriter.getContext()), - ValueRange{loopVarRowOfB, loopVarColOfB}); - affine::buildAffineLoopNest( - builder, loc, {zeroIndex}, {aRow}, 1, - [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarRowOfA = ivRange.front(); - Value aElement = builder.create( - loc, A, ValueRange{loopVarRowOfA, loopVarRowOfB}); - Value aVec = builder.create( + ValueRange{loopVarRowOfA, loopVarColOfB}); + auto iter_vec = builder.create( + loc, ValueRange{zeroIndex}, builder.getDimIdentityMap(), + ValueRange{bRow}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{cVec}, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs0){ + Value bVec = builder.create( + loc, VectorType::get({affineVectorSize}, elementType), B, + AffineMap::get(2, 0, {d0, d1}, + rewriter.getContext()), + ValueRange{iv1, loopVarColOfB}); + Value aElement = builder.create( + loc, A, ValueRange{loopVarRowOfA, iv1}); + Value aVec = builder.create( loc, VectorType::get({affineVectorSize}, elementType), aElement); - Value cVec = builder.create( - loc, VectorType::get({affineVectorSize}, elementType), C, - AffineMap::get(2, 0, {d0, d1 * affineVectorSize}, - builder.getContext()), - ValueRange{loopVarRowOfA, loopVarColOfB}); - Value computedVec; + Value computedVec; // Compute the result vector either through integer // multiplication and addition or fused multiply-add @@ -184,109 +174,124 @@ class MatMulParallelVectorizationPattern : public ConversionPattern { Value mulVec = builder.create(loc, aVec, bVec); computedVec = - builder.create(loc, mulVec, cVec); + builder.create(loc, mulVec, itrArgs0[0]); } else { computedVec = - builder.create(loc, aVec, bVec, cVec); - } + builder.create(loc, aVec, bVec, itrArgs0[0]); + } + builder.create(loc, computedVec); + }); builder.create( - loc, computedVec, C, - AffineMap::get(2, 0, {d0, d1 * affineVectorSize}, - builder.getContext()), - ValueRange{loopVarRowOfA, loopVarColOfB}); - }); - }); + loc, iter_vec.getResult(0), C, + AffineMap::get(2, 0, {d0, d1}, + builder.getContext()), + ValueRange{loopVarRowOfA, loopVarColOfB}); + }); // Branch handling operations on the tail. OpBuilder falseBranchBuilder = branchingOp.getElseBodyBuilder(); - affine::buildAffineLoopNest( - falseBranchBuilder, loc, {zeroIndex}, {bRow}, 1, + Value tailSize = falseBranchBuilder.create(loc, bCol, loopVarColOfB); + Value maskVector = falseBranchBuilder.create( + loc, VectorType::get({affineVectorSize}, rewriter.getI1Type()), + ValueRange{tailSize}); + affine::buildAffineLoopNest( + falseBranchBuilder, loc, {zeroIndex}, {aRow}, 1, [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarRowOfB = ivRange.front(); - Value tailIdxColOfB = builder.create( - loc, AffineMap::get(1, 0, d0 * affineVectorSize), - ValueRange{loopVarColOfB}); - Value bVec = builder.create( - loc, VectorType::get({affineVectorSize}, elementType), B, - ValueRange{loopVarRowOfB, tailIdxColOfB}, maskVector, - zeroElementTypeVec); - affine::buildAffineLoopNest( - builder, loc, {zeroIndex}, {aRow}, 1, - [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarRowOfA = ivRange.front(); - Value aElement = builder.create( - loc, A, ValueRange{loopVarRowOfA, loopVarRowOfB}); - Value aVec = builder.create( + Value loopVarRowOfA = ivRange.front(); + Value cVec = builder.create( + loc, VectorType::get({affineVectorSize}, elementType), C, + ValueRange{loopVarRowOfA, loopVarColOfB}, maskVector, + zeroElementTypeVec); + + auto iter_vec = builder.create( + loc, ValueRange{zeroIndex}, builder.getDimIdentityMap(), + ValueRange{bRow}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{cVec}, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs0){ + // Value bVec = builder.create( + // loc, VectorType::get({affineVectorSize}, elementType), B, + // AffineMap::get(2, 0, {d0, d1}, + // rewriter.getContext()), + // ValueRange{iv1, loopVarColOfB}); + Value bVec = builder.create( + loc, VectorType::get({affineVectorSize}, elementType), B, + ValueRange{iv1, loopVarColOfB}, maskVector, + zeroElementTypeVec); + + Value aElement = builder.create( + loc, A, ValueRange{loopVarRowOfA, iv1}); + Value aVec = builder.create( loc, VectorType::get({affineVectorSize}, elementType), aElement); - Value cVec = builder.create( - loc, VectorType::get({affineVectorSize}, elementType), C, - ValueRange{loopVarRowOfA, tailIdxColOfB}, maskVector, - zeroElementTypeVec); - Value computedVec; + Value computedVec; // Compute the result vector either through integer - // multiplication and addition or fused multiply-add based on - // the element type. + // multiplication and addition or fused multiply-add + // based on the element type. if (isa(elementType)) { Value mulVec = builder.create(loc, aVec, bVec); computedVec = - builder.create(loc, mulVec, cVec); + builder.create(loc, mulVec, itrArgs0[0]); } else { computedVec = - builder.create(loc, aVec, bVec, cVec); - } + builder.create(loc, aVec, bVec, itrArgs0[0]); + } + builder.create(loc, computedVec); + }); builder.create( - loc, C, ValueRange{loopVarRowOfA, tailIdxColOfB}, - maskVector, computedVec); - }); - }); + loc, C, ValueRange{loopVarRowOfA, loopVarColOfB}, + maskVector, iter_vec.getResult(0)); + }); } else { - affine::buildAffineLoopNest( - rewriter, loc, {zeroIndex}, {bRow}, 1, - [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarRowOfB = ivRange.front(); - Value bVec = builder.create( - loc, VectorType::get({affineVectorSize}, elementType), B, - AffineMap::get(2, 0, {d0, d1 * affineVectorSize}, - rewriter.getContext()), - ValueRange{loopVarRowOfB, loopVarColOfB}); - affine::buildAffineLoopNest( - builder, loc, {zeroIndex}, {aRow}, 1, - [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarRowOfA = ivRange.front(); + affine::buildAffineLoopNest( + rewriter, loc, {zeroIndex}, {aRow}, 1, + [&](OpBuilder &builder, Location loc, ValueRange ivRange) { + Value loopVarRowOfA = ivRange.front(); + Value cVec = builder.create( + loc, VectorType::get({affineVectorSize}, elementType), C, + AffineMap::get(2, 0, {d0, d1}, + rewriter.getContext()), + ValueRange{loopVarRowOfA, loopVarColOfB}); + auto iter_vec = builder.create( + loc, ValueRange{zeroIndex}, builder.getDimIdentityMap(), + ValueRange{bRow}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{cVec}, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs0){ + Value bVec = builder.create( + loc, VectorType::get({affineVectorSize}, elementType), B, + AffineMap::get(2, 0, {d0, d1}, + rewriter.getContext()), + ValueRange{iv1, loopVarColOfB}); Value aElement = builder.create( - loc, A, ValueRange{loopVarRowOfA, loopVarRowOfB}); + loc, A, ValueRange{loopVarRowOfA, iv1}); Value aVec = builder.create( - loc, VectorType::get({affineVectorSize}, elementType), - aElement); - Value cVec = builder.create( - loc, VectorType::get({affineVectorSize}, elementType), C, - AffineMap::get(2, 0, {d0, d1 * affineVectorSize}, - builder.getContext()), - ValueRange{loopVarRowOfA, loopVarColOfB}); + loc, VectorType::get({affineVectorSize}, elementType), + aElement); Value computedVec; - // Compute the result vector either through integer - // multiplication and addition or fused multiply-add - // based on the element type. - if (isa(elementType)) { - Value mulVec = - builder.create(loc, aVec, bVec); - computedVec = - builder.create(loc, mulVec, cVec); - } else { - computedVec = - builder.create(loc, aVec, bVec, cVec); - } - builder.create( - loc, computedVec, C, - AffineMap::get(2, 0, {d0, d1 * affineVectorSize}, - builder.getContext()), - ValueRange{loopVarRowOfA, loopVarColOfB}); + // Compute the result vector either through integer + // multiplication and addition or fused multiply-add + // based on the element type. + if (isa(elementType)) { + Value mulVec = + builder.create(loc, aVec, bVec); + computedVec = + builder.create(loc, mulVec, itrArgs0[0]); + } else { + computedVec = + builder.create(loc, aVec, bVec, itrArgs0[0]); + } + builder.create(loc, computedVec); + }); + builder.create( + loc, iter_vec.getResult(0), C, + AffineMap::get(2, 0, {d0, d1}, + builder.getContext()), + ValueRange{loopVarRowOfA, loopVarColOfB}); }); - }); } rewriter.create(loc);