
[LAYOUTS] Implement getShapePerCTATile via LLs (#6026)
This function should actually be removed once we have a proper lowering
of SIMD instructions via divideLeft.
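In rough terms, the commit replaces the per-encoding computation with a query on the tensor's linear encoding. A minimal sketch of that direction, simplified from the diff below rather than copied verbatim from it:

// Sketch only: the tile shape is read off the LinearEncodingAttr that the
// tensor's encoding converts to, instead of being special-cased per layout.
SmallVector<unsigned> getShapePerCTATile(RankedTensorType type) {
  LinearEncodingAttr linear =
      toLinearEncoding(type.getEncoding(), type.getShape());
  // Per dimension: sizePerThread * threadsPerWarp * warpsPerCTA.
  return linear.getShapePerCTATile();
}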
lezcano authored Feb 28, 2025
1 parent 37442c3 commit 161e713
Showing 14 changed files with 97 additions and 96 deletions.
5 changes: 3 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -83,6 +83,7 @@ struct SharedMemory : public SideEffects::Resource::Base<SharedMemory> {
};

// Convert a distributed layout to a linear encoding
LinearEncodingAttr toLinearEncoding(RankedTensorType type);
LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef<int64_t> shape);

unsigned getTotalElemsPerThread(Type type);
@@ -181,7 +182,8 @@ SmallVector<unsigned> getCTAOrder(Attribute layout);
* (3) In the implementation of emitIndices, ShapePerCTATile will
* be replicated or wrapped to fit ShapePerCTA.
*/
SmallVector<unsigned> getShapePerCTATile(Attribute layout);
// [FIXME LL] Kill this function
SmallVector<unsigned> getShapePerCTATile(RankedTensorType layout);

// Returns the "logical" shape per CTA
SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
@@ -238,7 +240,6 @@ llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);

llvm::SmallVector<unsigned>
expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);

} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
3 changes: 3 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -649,6 +649,9 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
// If skipBroadcast is false, we count a base zero
SmallVector<unsigned> basesPerDim(StringAttr dimName,
bool skipBroadcast = true) const;

// [FIXME LL] Supports legacy behaviour. We should remove this function.
SmallVector<unsigned> getShapePerCTATile() const;
}];

let genVerifyDecl = 1;
4 changes: 2 additions & 2 deletions lib/Analysis/Allocation.cpp
@@ -49,8 +49,8 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,

auto srcShapePerCTA = gpu::getShapePerCTA(srcTy);
auto dstShapePerCTA = gpu::getShapePerCTA(dstTy);
auto srcShapePerCTATile = gpu::getShapePerCTATile(srcLayout);
auto dstShapePerCTATile = gpu::getShapePerCTATile(dstLayout);
auto srcShapePerCTATile = gpu::getShapePerCTATile(srcTy);
auto dstShapePerCTATile = gpu::getShapePerCTATile(dstTy);

assert(srcTy.getRank() == dstTy.getRank() &&
"src and dst must have the same rank");
13 changes: 9 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp
@@ -103,11 +103,16 @@ LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
Value llA = adaptor.getA();
Value llB = adaptor.getB();

auto sizePerThread =
expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
auto sizePerThread = getContigPerThread(dTensorTy);
auto numElemsPerThread = product(sizePerThread);
auto shapePerCTATile =
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));
SmallVector<unsigned> shapePerCTATile;
for (auto [reg, thread, warp] :
llvm::zip(sizePerThread, dLayout.getThreadsPerWarp(),
dLayout.getWarpsPerCTA())) {
shapePerCTATile.push_back(reg * thread * warp);
}
shapePerCTATile = expandMatrixShapeWithBatch(ArrayRef(shapePerCTATile));
sizePerThread = expandMatrixShapeWithBatch(ArrayRef(sizePerThread));

unsigned K = aShapePerCTA[2];

4 changes: 1 addition & 3 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -349,7 +349,6 @@ struct ReduceOpConversion
auto resultIndices = emitIndices(loc, rewriter, targetInfo,
resultLayout, resultTy, true);
auto resultShape = resultTy.getShape();
auto resultCTATile = getShapePerCTATile(resultLayout);
assert(resultIndices.size() == resultElems);

SmallVector<Value> resultVals(resultElems);
@@ -359,8 +358,7 @@
for (size_t resultIdx = 0, resultDim = resultShape.size();
resultIdx < resultDim; ++resultIdx) {
auto smemIdx = resultIdx < op.getAxis() ? resultIdx : resultIdx + 1;
if (resultCTATile[resultIdx] > smemShape[smemIdx] ||
resultShape[resultIdx] > smemShape[smemIdx]) {
if (resultShape[resultIdx] > smemShape[smemIdx]) {
// When srcShape is smaller than src sizePerThread, only srcShape
// elements are accumulated in smem. Modulo smemShape effectively
// replicates srcShape elements to src sizePerThread.
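The wrapping behaviour described in the comment above can be illustrated with made-up sizes; this is a hedged sketch of the idea, not the exact indexing code in the conversion:

// Assumed example: smemShape[dim] = 4 while resultShape[dim] = 8. Result
// indices 0..7 then read smem slots 0,1,2,3,0,1,2,3, i.e. the accumulated
// elements are replicated across the larger result extent.
unsigned wrappedSmemIndex(unsigned resultIdx, unsigned smemDimSize) {
  return resultIdx % smemDimSize;
}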
1 change: 0 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -586,7 +586,6 @@ SmallVector<SmallVector<unsigned>> emitOffsetForLayout(Attribute layout,
namespace LLVM {
using namespace mlir::triton;
using mlir::triton::gpu::getOrder;
using mlir::triton::gpu::getSizePerThread;

Value createConstantI1(Location loc, OpBuilder &rewriter, bool v) {
auto i1ty = rewriter.getIntegerType(1);
41 changes: 18 additions & 23 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -33,6 +33,10 @@ namespace mlir {
namespace triton {
namespace gpu {

LinearEncodingAttr toLinearEncoding(RankedTensorType type) {
return toLinearEncoding(type.getEncoding(), type.getShape());
}

LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef<int64_t> shape) {
auto linearLayout = toLinearLayout(shape, layout);
return LinearEncodingAttr::get(layout.getContext(), std::move(linearLayout));
@@ -121,29 +125,8 @@ SmallVector<unsigned> getContigPerThread(RankedTensorType tensorType) {
return llAttr.getContigPerThread();
}

SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
if (auto distributedLayout =
mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
auto sizePerThread = distributedLayout.getSizePerThread();
auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
// ThreadsPerWarp does not align with this function for slice layout
if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
}
auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
assert(sizePerThread.size() == threadsPerWarp.size() &&
sizePerThread.size() == warpsPerCTA.size());
SmallVector<unsigned> shape;
for (auto [size, thread, warp] :
llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) {
shape.push_back(size * thread * warp);
}
return shape;
} else {
llvm::report_fatal_error("getShapePerCTATile not implemented");
return SmallVector<unsigned>();
}
SmallVector<unsigned> getShapePerCTATile(RankedTensorType type) {
return toLinearEncoding(type).getShapePerCTATile();
}

bool isExpensiveView(Type srcType, Type dstType) {
@@ -1283,6 +1266,18 @@ SmallVector<unsigned> LinearEncodingAttr::getSizePerThread() const {
return basesPerDimImpl(bases, kRegister, rank);
}

SmallVector<unsigned> LinearEncodingAttr::getShapePerCTATile() const {
auto sizePerThread = getSizePerThread();
auto threadsPerWarp = getThreadsPerWarp();
auto warpsPerCTA = getWarpsPerCTA();
SmallVector<unsigned> shape;
for (auto [size, thread, warp] :
llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) {
shape.push_back(size * thread * warp);
}
return shape;
}

SmallVector<unsigned> LinearEncodingAttr::getDefaultOrder() const {
return getOrder();
}
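As a concrete illustration of the new LinearEncodingAttr::getShapePerCTATile helper, here is a self-contained example with hypothetical layout parameters (the values are assumptions, not taken from this commit):

#include <cstdio>

int main() {
  // Hypothetical rank-2 blocked layout.
  unsigned sizePerThread[2] = {1, 4};
  unsigned threadsPerWarp[2] = {8, 4};
  unsigned warpsPerCTA[2] = {4, 1};
  for (int d = 0; d < 2; ++d)
    std::printf("shapePerCTATile[%d] = %u\n", d,
                sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d]);
  // Prints 32 and 16: each dimension is the product of the three factors,
  // exactly as in the loop added above.
  return 0;
}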
2 changes: 1 addition & 1 deletion test/Analysis/test-allocation.mlir
@@ -48,7 +48,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
// expected-remark @below {{scratch offset = 0, size = 4608}}
%a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT>
%b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr<f16>, #BL>
// expected-remark @below {{scratch offset = 0, size = 4352}}
// expected-remark @below {{scratch offset = 0, size = 2304}}
%b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT>

%c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
5 changes: 4 additions & 1 deletion test/Conversion/amd/buffer_load_store.mlir
@@ -66,8 +66,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
// CHECK-LABEL: buffer_store
tt.func @buffer_store(%value : tensor<128xf32, #blocked0>, %arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}) {
// CHECK: llvm.mlir.constant(true) : i1
// CHECK: %[[c_mask:.*]] = llvm.mlir.constant(true) : i1
// CHECK: %[[offset:.*]] = llvm.select %[[c_mask]]
// CHECK: %[[w_mask:.*]] = llvm.mlir.constant(true) : i1
// CHECK: %[[mask:.*]] = llvm.and %[[c_mask]], %[[w_mask]]
// CHECK: %[[offset:.*]] = llvm.select %[[mask]]
// CHECK: %[[aux:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: rocdl.raw.ptr.buffer.store {{.*}}, {{.*}}, %[[offset]], {{.*}}, %[[aux]]
%c256_i32 = arith.constant 256 : i32
25 changes: 23 additions & 2 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -1194,9 +1194,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
// CHECK: llvm.mlir.global external @global_smem
// CHECK-LABEL: convert_layout_mmav3_transpose
tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) {
// CHECK-COUNT-128: st.shared.b8
// CHECK-COUNT-16: st.shared.b8
// CHECK: nvvm.barrier0
// CHECK-COUNT-8: llvm.load {{.*}} -> vector<4xi32>
// CHECK: llvm.load {{.*}} -> vector<4xi32>
// CHECK-COUNT-16: st.shared.b8
// CHECK: nvvm.barrier0
// CHECK: llvm.load {{.*}} -> vector<4xi32>
// CHECK-COUNT-16: st.shared.b8
// CHECK: nvvm.barrier0
// CHECK: llvm.load {{.*}} -> vector<4xi32>
// CHECK-COUNT-16: st.shared.b8
// CHECK: nvvm.barrier0
// CHECK: llvm.load {{.*}} -> vector<4xi32>
// CHECK-COUNT-16: st.shared.b8
// CHECK: nvvm.barrier0
// CHECK: llvm.load {{.*}} -> vector<4xi32>
// CHECK-COUNT-16: st.shared.b8
// CHECK: nvvm.barrier0
// CHECK: llvm.load {{.*}} -> vector<4xi32>
// CHECK-COUNT-16: st.shared.b8
// CHECK: nvvm.barrier0
// CHECK: llvm.load {{.*}} -> vector<4xi32>
// CHECK-COUNT-16: st.shared.b8
// CHECK: nvvm.barrier0
// CHECK: llvm.load {{.*}} -> vector<4xi32>
%0 = ttg.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked>
tt.return
}
10 changes: 5 additions & 5 deletions third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
@@ -80,11 +80,6 @@ LogicalResult ExtractSliceOp::verify() {
}

auto srcShape = srcTy.getShape();
auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcLayout);
shapePerCTATile[0] =
std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
shapePerCTATile[1] =
std::min(static_cast<unsigned>(srcShape[1]), shapePerCTATile[1]);

// ExtractSlice only supports slicing where offsets and sizes are multiples of
// shapePerCTATile. This condition ensures that slice has the same layout as
@@ -117,6 +112,11 @@ LogicalResult ExtractSliceOp::verify() {
sizes.push_back(resultDimSize);
}

auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcTy);
shapePerCTATile[0] =
std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
shapePerCTATile[1] =
std::min(static_cast<unsigned>(srcShape[1]), shapePerCTATile[1]);
if (sizes[0] % shapePerCTATile[0] != 0 ||
sizes[1] % shapePerCTATile[1] != 0) {
return emitError() << "sizes [" << sizes
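For intuition about the verifier condition above (offsets and sizes must be multiples of the clamped shapePerCTATile), a hedged illustration with assumed numbers rather than values from the commit:

// Assumed example: srcShape = {128, 64}, clamped shapePerCTATile = {32, 32}.
// sizes = {64, 32} passes (64 % 32 == 0 and 32 % 32 == 0);
// sizes = {48, 32} is rejected (48 % 32 != 0).
bool sizesAreTileMultiples(const unsigned sizes[2],
                           const unsigned shapePerCTATile[2]) {
  return sizes[0] % shapePerCTATile[0] == 0 &&
         sizes[1] % shapePerCTATile[1] == 0;
}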
@@ -65,12 +65,12 @@ struct ExtractSliceOpConversion
auto resultTy = cast<RankedTensorType>(op.getType());
auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter);
auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy);
auto sizePerThread = triton::gpu::getSizePerThread(srcLayout);
auto totalSizePerThread = product<unsigned>(sizePerThread);
auto contigPerThread = triton::gpu::getContigPerThread(srcTy);
auto totalContigPerThread = product<unsigned>(contigPerThread);
auto order = triton::gpu::getOrder(srcTy);

// Calculate valid total number of workers in each dimension
auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout);
auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcTy);
shapePerCTATile[0] =
std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
shapePerCTATile[1] =
@@ -94,21 +94,21 @@

// The diagram above illustrates the graphical representation of the
// skipElems, tensorStride, and lastIdx variables.
auto skipElems = CTAOffsets[order[1]] *
(elemsPerThread[order[0]] * sizePerThread[order[1]]) +
CTAOffsets[order[0]] * totalSizePerThread;
auto skipElems = CTAOffsets[order[1]] * (elemsPerThread[order[0]] *
contigPerThread[order[1]]) +
CTAOffsets[order[0]] * totalContigPerThread;
auto tensorStride =
(CTAPerShape[order[0]] - CTASizes[order[0]]) * totalSizePerThread;
(CTAPerShape[order[0]] - CTASizes[order[0]]) * totalContigPerThread;
auto lastIdx =
(CTAOffsets[order[1]] + CTASizes[order[1]] - 1) *
elemsPerThread[order[0]] * sizePerThread[order[1]] +
(CTAOffsets[order[0]] + CTASizes[order[0]]) * totalSizePerThread;
elemsPerThread[order[0]] * contigPerThread[order[1]] +
(CTAOffsets[order[0]] + CTASizes[order[0]]) * totalContigPerThread;

assert(lastIdx <= vals.size());

SmallVector<Value> resultVals;
for (int i = skipElems; i < lastIdx; i += tensorStride) {
for (int j = 0; j < totalSizePerThread * CTASizes[order[0]]; ++j, ++i) {
for (int j = 0; j < totalContigPerThread * CTASizes[order[0]]; ++j, ++i) {
assert(i < lastIdx);
resultVals.push_back(vals[i]);
}
54 changes: 14 additions & 40 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -33,50 +33,23 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
Location loc, const AMD::TargetInfo &targetInfo) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
auto tensorTy = dyn_cast<RankedTensorType>(valueTy);
Value mask = b.int_val(1, 1);
Value mask = b.true_val();
auto tid = getThreadId(rewriter, loc);
auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc);
if (tensorTy) {
auto layout = tensorTy.getEncoding();
// To remove this use, port https://github.com/triton-lang/triton/pull/5432
// to the AMDGPU dialect
auto layout = cast<DistributedEncodingTrait>(tensorTy.getEncoding());
auto shape = tensorTy.getShape();
unsigned rank = shape.size();
auto sizePerThread = triton::gpu::getSizePerThread(layout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
auto threadOrder = triton::gpu::getThreadOrder(tensorTy);
SmallVector<unsigned> warpOrder(rank);
if (auto enc = dyn_cast<DotOperandEncodingAttr>(layout)) {
warpOrder =
triton::gpu::getMatrixOrder(rank, /*rowMajor=*/enc.getOpIdx() == 1);
} else {
warpOrder = triton::gpu::getWarpOrder(tensorTy);
}
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
Value warpSize = b.i32_val(triton::gpu::getWarpSize(layout));
Value laneId = b.urem(tid, warpSize);
Value warpId = b.udiv(tid, warpSize);
// TODO: [DOT LL]
// The delinearize function is not entirely correct for certain layouts,
// such as wgmma. The correct approach is to convert a legacy layout to its
// corresponding linear layout and use the linear layout's
// getFreeVariableMasks to identify redundant elements.
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
for (unsigned dim = 0; dim < rank; ++dim) {
// if there is no data replication across threads on this dimension
if (shape[dim] >= shapePerCTATile[dim])
continue;
// Otherwise, we need to mask threads that will replicate data on this
// dimension. Calculate the thread index on this dimension for the CTA
Value threadDim =
b.add(b.mul(multiDimWarpId[dim], b.i32_val(threadsPerWarp[dim])),
multiDimThreadId[dim]);
mask = b.and_(mask,
b.icmp_slt(b.mul(threadDim, b.i32_val(sizePerThread[dim])),
b.i32_val(shape[dim])));
}
auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
auto kLane = StringAttr::get(rewriter.getContext(), "lane");
auto kWarp = StringAttr::get(rewriter.getContext(), "warp");
auto maskLane =
std::get<1>(delinearize(rewriter, loc, layout, shape, kLane, laneId));
auto maskWarp =
std::get<1>(delinearize(rewriter, loc, layout, shape, kWarp, warpId));
mask = b.and_(maskLane, maskWarp);

// Do not write duplicated data when multicast is enabled
if (triton::gpu::getNumCTAs(layout) > 1) {
auto _0 = b.i32_val(0);
@@ -87,6 +60,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
auto multiDimClusterCTAId =
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);

auto rank = tensorTy.getRank();
for (unsigned dim = 0; dim < rank; ++dim) {
// Skip when multicast is not enabled in this dimension
if (CTAsPerCGA[dim] == CTASplitNum[dim])
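The removed TODO spells out the idea behind the new mask: in a linear layout, a lane or warp bit that does not change which element a thread holds is a free variable, and only the thread whose free bits are all zero should write. A rough sketch of that idea follows; the helper name and mask parameters are assumptions for illustration, not the actual Triton API used by delinearize above:

// Sketch only: a thread is non-redundant iff all of its "free" lane and warp
// bits are zero, so each replicated value is written exactly once.
bool isNonRedundantThread(unsigned laneId, unsigned warpId,
                          unsigned laneFreeVarMask, unsigned warpFreeVarMask) {
  return (laneId & laneFreeVarMask) == 0 && (warpId & warpFreeVarMask) == 0;
}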
@@ -31,7 +31,6 @@ using namespace mlir::triton::NVIDIA;

using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::MemDescType;
using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
using ::mlir::triton::gpu::NVMMASharedEncodingAttr;
@@ -379,7 +378,10 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
int N = instrShape[1];
int K = instrShape[2];
bool zeroAcc = isZeroConst(c);
auto shapePerCTATile = getShapePerCTATile(mmaEncoding);
auto instrMNK = mmaEncoding.getInstrShape();
auto warpSize = mmaEncoding.getWarpsPerCTA();
auto shapePerCTATile = SmallVector<unsigned>{instrMNK[0] * warpSize[0],
instrMNK[1] * warpSize[1]};
int numRepM = ceil<unsigned>(dShapePerCTA[0], shapePerCTATile[0]);
int numRepN = ceil<unsigned>(dShapePerCTA[1], shapePerCTATile[1]);
int numRepK = ceil<unsigned>(aTensorTy.getShape()[1], instrShape[2]);
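To make the inlined shapePerCTATile computation above concrete, a small self-contained example with hypothetical MMA parameters (illustrative values only, not from the commit):

#include <cstdio>

int main() {
  unsigned instrMNK[2] = {64, 32};      // assumed per-instruction M, N tile
  unsigned warpsPerCTA[2] = {4, 2};     // assumed warp arrangement
  unsigned dShapePerCTA[2] = {256, 128};
  unsigned shapePerCTATile[2] = {instrMNK[0] * warpsPerCTA[0],
                                 instrMNK[1] * warpsPerCTA[1]};  // {256, 64}
  unsigned numRepM =
      (dShapePerCTA[0] + shapePerCTATile[0] - 1) / shapePerCTATile[0];  // 1
  unsigned numRepN =
      (dShapePerCTA[1] + shapePerCTATile[1] - 1) / shapePerCTATile[1];  // 2
  std::printf("numRepM = %u, numRepN = %u\n", numRepM, numRepN);
  return 0;
}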
