Lowering of BufferLoadToLocal to buffer load to lds
AlexAUT committed Feb 25, 2025
1 parent dce695e commit 489d5ee
Showing 8 changed files with 367 additions and 49 deletions.
161 changes: 161 additions & 0 deletions test/Conversion/amd/buffer_load_to_local_to_llvm.mlir
@@ -0,0 +1,161 @@
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s --check-prefixes=CHECK,GFX942

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: buffer_load_to_local_simple
tt.func public @buffer_load_to_local_simple(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: !tt.ptr<f16>,
%arg2: tensor<32x64xi32, #blocked>,
%arg3: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
// Each thread needs to load 8 elements, and we load 1 element (sizePerThread) per buffer load instruction
// CHECK: rocdl.make.buffer.rsrc %arg1
// CHECK-NOT: rocdl.make.buffer.rsrc
// CHECK-COUNT-8: rocdl.raw.ptr.buffer.load.lds
// CHECK-NOT: rocdl.raw.ptr.buffer.load.lds
%65 = amdgpu.buffer_load_to_local %arg1[%arg2], %arg3 {OpIdx = #amdgpu.OpIdx<1>} : <f16>[tensor<32x64xi32, #blocked>] -> <32x64xf16, #shared, #smem, mutable>
tt.return
}
}
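As a quick sanity check on the CHECK-COUNT-8 above, the expected instruction count follows from the layout alone. Below is a minimal C++ sketch of that arithmetic; the helper name is invented for illustration and is not part of the patch.

#include <cassert>

// Per-thread load count implied by a blocked layout: total elements divided by
// the number of threads, then divided by the contiguous elements each buffer
// load instruction moves (sizePerThread along the fast-moving dimension).
int expectedLoadsPerThread(int numElements, int numWarps, int threadsPerWarp,
                           int elemsPerLoad) {
  int elemsPerThread = numElements / (numWarps * threadsPerWarp); // 2048 / 256 = 8
  return elemsPerThread / elemsPerLoad;                           // 8 / 1 = 8
}

int main() {
  // 32x64 tensor, 4 warps of 64 threads, sizePerThread = [1, 1].
  assert(expectedLoadsPerThread(32 * 64, 4, 64, 1) == 8); // matches CHECK-COUNT-8
  return 0;
}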

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: buffer_load_to_local_vectorized_2xf16
tt.func public @buffer_load_to_local_vectorized_2xf16(
%arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>,
%arg3: i32) {
%1 = tt.splat %arg3: i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
%3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
// Each thread needs to load 8 elements, and we load 2 elements (sizePerThread) per buffer load instruction
// CHECK: rocdl.make.buffer.rsrc
// CHECK-NOT: rocdl.make.buffer.rsrc
// CHECK-COUNT-4: rocdl.raw.ptr.buffer.load.lds
// CHECK-NOT: rocdl.raw.ptr.buffer.load.lds
%65 = amdgpu.buffer_load_to_local %arg1[%3], %arg2 : <f16>[tensor<32x64xi32, #blocked>] -> <32x64xf16, #shared, #smem, mutable>
tt.return
}
}
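The same arithmetic applies here: each thread still owns 2048 / 256 = 8 elements, but with sizePerThread = [1, 2] every rocdl.raw.ptr.buffer.load.lds moves two f16 values (4 bytes), so the lowering needs 8 / 2 = 4 instructions, which is what CHECK-COUNT-4 above verifies.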

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// GFX950-LABEL: buffer_load_to_local_vectorized_8xf16
// GFX942-LABEL: buffer_load_to_local_vectorized_8xf16
tt.func public @buffer_load_to_local_vectorized_8xf16(
%arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>,
%arg3: i32) {
%1 = tt.splat %arg3: i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
%3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>

// Each thread needs to load 8 elements, and we load 8 elements (sizePerThread) per buffer load instruction
// GFX950: rocdl.make.buffer.rsrc
// GFX950-NOT: rocdl.make.buffer.rsrc
// GFX950: rocdl.raw.ptr.buffer.load.lds
// GFX950-NOT: rocdl.raw.ptr.buffer.load.lds

// GFX942 does not support vectorization larger than 4 bytes, so we cannot lower this op
// GFX942-NOT: rocdl.raw.ptr.buffer.load.lds
// GFX942: amdgpu.buffer_load_to_local
%65 = amdgpu.buffer_load_to_local %arg1[%3], %arg2 : <f16>[tensor<32x64xi32, #blocked>] -> <32x64xf16, #shared, #smem, mutable>
tt.return
}
}
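The split between the two architectures above comes down to the maximum transfer width of a single buffer-load-to-LDS instruction. A small sketch of that decision follows; the byte limits are inferred from this test rather than quoted from the pass, and the function names are invented.

// Assumption based on the checks above: gfx942 buffer loads to LDS move at
// most 4 bytes per instruction, while gfx950 can move a full 16-byte (8 x f16)
// access, so only gfx950 lowers this layout to a single
// rocdl.raw.ptr.buffer.load.lds and gfx942 keeps amdgpu.buffer_load_to_local.
constexpr int maxLoadToLdsBytes(bool isGfx950) { return isGfx950 ? 16 : 4; }

constexpr bool canLowerAsSingleLoad(int bytesPerAccess, bool isGfx950) {
  return bytesPerAccess <= maxLoadToLdsBytes(isGfx950);
}

static_assert(canLowerAsSingleLoad(8 * 2, /*isGfx950=*/true),
              "gfx950 lowers 8xf16 in one instruction");
static_assert(!canLowerAsSingleLoad(8 * 2, /*isGfx950=*/false),
              "gfx942 keeps the op unlowered");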

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: buffer_load_to_local_mask_other
// GFX950-LABEL: buffer_load_to_local_mask_other
tt.func public @buffer_load_to_local_mask_other(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: !tt.ptr<f16>,
%arg2: tensor<32x32xi32, #blocked>,
%arg3: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>,
%arg4: i32) {
// We need the splat to allow the AxisAnalysis to work during lowering
%cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked>
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
%c31_i32 = arith.constant 31 : i32
%1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
%29 = arith.addi %arg4, %c31_i32 : i32
%30 = arith.divsi %29, %c32_i32 : i32
%31 = arith.cmpi sgt, %30, %c0_i32 : i32

%51 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%52 = tt.expand_dims %51 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
%65 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked>
%66 = arith.cmpi slt, %52, %65 : tensor<32x1xi32, #blocked>
%67 = tt.broadcast %66 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>

%70 = tt.splat %31 : i1 -> tensor<32x32xi1, #blocked>
%71 = arith.andi %70, %67 : tensor<32x32xi1, #blocked>

// Each thread needs to load 4 elements, and we load 1 element (sizePerThread) per buffer load instruction
// Note that the mask/other alignment is 1, so we need 4 conditionals

// CHECK: rocdl.raw.ptr.buffer.load.lds
// CHECK: _predicated_store

// CHECK: rocdl.raw.ptr.buffer.load.lds
// CHECK: _predicated_store

// CHECK: rocdl.raw.ptr.buffer.load.lds
// CHECK: _predicated_store

// CHECK: rocdl.raw.ptr.buffer.load.lds
// CHECK: _predicated_store

// CHECK-NOT: rocdl.raw.ptr.buffer.load.lds
// CHECK-NOT: _predicated_store

amdgpu.buffer_load_to_local %arg1[%arg2], %arg3 mask %67 other %cst_0 {OpIdx = #amdgpu.OpIdx<1>} : <f16>[tensor<32x32xi32, #blocked>] tensor<32x32xf16, #blocked> -> <32x32xf16, #shared, #smem, mutable>
tt.return
}
}
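For reference, here is a scalar C++ model of the semantics that the four load/predicated-store pairs above implement. It models the result, not the emitter's code, and all names are invented.

#include <cstdint>
#include <vector>

// For every element: copy the global value into shared memory, or write the
// `other` value instead when the mask is false. In the lowering this shows up
// as a rocdl.raw.ptr.buffer.load.lds followed by a predicated store of `other`.
void bufferLoadToLocalMaskOther(const std::vector<uint16_t> &global,
                                const std::vector<bool> &mask, uint16_t other,
                                std::vector<uint16_t> &lds) {
  for (size_t i = 0; i < lds.size(); ++i)
    lds[i] = mask[i] ? global[i] : other;
}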

// -----


#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [16, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: buffer_load_to_local_cache_mods
tt.func public @buffer_load_to_local_cache_mods(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: !tt.ptr<f16>,
%arg2: tensor<32x32xi32, #blocked>,
%arg3: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>) {
// The first matched constant 0 is the LDS offset (which is also 0); consuming it first ensures %[[aux_ca]] captures the cache-modifier (AUX) constant
// CHECK: llvm.getelementptr
// CHECK: llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]]
%1 = amdgpu.buffer_load_to_local %arg1[%arg2], %arg3 cacheModifier = ca: <f16>[tensor<32x32xi32, #blocked>] -> <32x32xf16, #shared, #smem, mutable>
// CHECK: llvm.getelementptr
// CHECK: %[[aux_cg:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cg]]
%2 = amdgpu.buffer_load_to_local %arg1[%arg2], %arg3 cacheModifier = cg: <f16>[tensor<32x32xi32, #blocked>] -> <32x32xf16, #shared, #smem, mutable>
// CHECK: llvm.getelementptr
// CHECK: %[[aux_cv:.*]] = llvm.mlir.constant(17 : i32) : i32
// CHECK: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cv]]
%3 = amdgpu.buffer_load_to_local %arg1[%arg2], %arg3 cacheModifier = cv: <f16>[tensor<32x32xi32, #blocked>] -> <32x32xf16, #shared, #smem, mutable>

tt.return
}
}
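The three aux constants checked above encode the cache modifier for gfx942. The sketch below simply reproduces the observed mapping; the raw values 0, 3 and 17 come from the CHECK lines, and the bit-level meaning of the AUX field is not spelled out here.

#include <cstdint>

enum class CacheModifier { CA, CG, CV };

// Mapping observed in the test above (gfx942).
int32_t auxForCacheModifier(CacheModifier cm) {
  switch (cm) {
  case CacheModifier::CA: return 0;  // %[[aux_ca]]
  case CacheModifier::CG: return 3;  // %[[aux_cg]]
  case CacheModifier::CV: return 17; // %[[aux_cv]]
  }
  return 0;
}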
22 changes: 22 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp
@@ -104,6 +104,28 @@ Value BufferEmitter::emitLoad(Type type, Value rsrcDesc, Value offset,
return data;
}

void BufferEmitter::emitLoadToLds(Type type, Value byteWidth, Value rsrcDesc,
Value offset, Value dst, Value pred,
triton::CacheModifier cm) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
SmallVector<Value, 6> commonArgs;
fillCommonArgs(type, rsrcDesc, offset, pred, cm, /*isBufferLoad=*/true,
commonArgs);
Type bufferType = getBufferOpType(type, false);
rewriter.create<ROCDL::RawPtrBufferLoadLdsOp>(
loc, TypeRange{},
ValueRange{
commonArgs[0], // Buffer descriptor
dst, // LDS base ptr
byteWidth, // Instr size
commonArgs[1], // Buffer offset
b.i32_val(0), // LDS offset
commonArgs[2], // Instruction offset
commonArgs[3], // AUX
},
ArrayRef<NamedAttribute>());
}
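For orientation, the operand order in the rocdl.raw.ptr.buffer.load.lds builder call mirrors the inline comments: resource descriptor, LDS base pointer, transfer size in bytes, buffer offset, LDS offset, instruction offset and AUX. The helper hard-codes the LDS offset to 0 and expects dst to already point at the destination slot in shared memory, which is why the cache-modifier test above has to skip one constant 0 before capturing the AUX value.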

Value BufferEmitter::emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc,
Value offset, Value data, Value pred,
bool hasUsers) {
4 changes: 4 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h
@@ -70,6 +70,10 @@ struct BufferEmitter {
Value emitLoad(Type type, Value rsrcDesc, Value offset, Value pred,
Value falseVal, CacheModifier cm);

// Emit a predicated rocdl.raw.ptr.buffer.load.lds
void emitLoadToLds(Type type, Value byteWidth, Value rsrcDesc, Value offset,
Value dst, Value pred, CacheModifier cm);

// Emit a predicated rocdl.raw.ptr.buffer.atomic.* RMWOp
Value emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc, Value offset,
Value data, Value pred, bool hasUsers);
