[BACKEND] Don't allocate shmem for warps with repeated data in tt.scan (#5910)

It turns out that the previous changes within reduce to support LLs (linear layouts) had
already trimmed its shmem use to the right size.
lezcano authored Feb 13, 2025
1 parent 464d1f1 commit de650ad
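
For intuition, the sizing idea is that the scan's scratch buffer is derived from the number of warps described by the operand's layout (the product of its warpsPerCTA) rather than from the kernel's launch-time warp count, so warps that only hold repeated copies of the data do not inflate the allocation. Below is a minimal standalone C++ sketch of that idea; it is not Triton's actual helper API, the layout values are hypothetical, and the real formula (see the Utility.cpp hunk below) includes further factors not shown in this excerpt.

    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in for product(getWarpsPerCTA()): multiply the
    // per-dimension warp counts reported by the layout.
    static unsigned product(const std::vector<unsigned> &dims) {
      unsigned p = 1;
      for (unsigned d : dims)
        p *= d;
      return p;
    }

    int main() {
      // Hypothetical scenario: the kernel is launched with 8 warps, but the
      // operand's layout distributes distinct data over only a 2x2 warp grid;
      // the remaining warps hold repeated copies.
      unsigned launchNumWarps = 8;                // launch-time warp count
      std::vector<unsigned> warpsPerCTA = {2, 2}; // what the layout describes
      unsigned nonAxisThreadsPerWarp = 8;         // hypothetical
      unsigned nonAxisElemsPerThread = 2;         // hypothetical

      unsigned perWarpElems = nonAxisThreadsPerWarp * nonAxisElemsPerThread;
      unsigned before = launchNumWarps * perWarpElems;      // old sizing
      unsigned after = product(warpsPerCTA) * perWarpElems; // new sizing
      std::printf("scratch elems: before=%u after=%u\n", before, after);
      return 0;
    }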
Showing 2 changed files with 11 additions and 1 deletion.
lib/Analysis/Utility.cpp (2 changes: 1 addition & 1 deletion)
@@ -267,7 +267,7 @@ bool ScanLoweringHelper::isSupported() {
 }
 
 unsigned ScanLoweringHelper::getScratchSizeInElems() {
-  unsigned numWarps = lookupNumWarps(scanOp);
+  unsigned numWarps = product(getEncoding().getWarpsPerCTA());
   unsigned numNonAxisElementsPerWarp =
       getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
   unsigned numElements = numWarps * numNonAxisElementsPerWarp *
test/Analysis/test-allocation.mlir (10 changes: 10 additions & 0 deletions)
@@ -615,4 +615,14 @@ tt.func @call_graph_2(%A : !tt.ptr<f16>, %cond : i1) {
   // CHECK-NEXT: size = 1024
 }
 
+// CHECK-LABEL: scan_alloc
+tt.func @scan_alloc(%x : tensor<8x16xf32, #AL>) {
+  // CHECK: offset = 0, size = 128
+  %a = "tt.scan"(%x) <{axis = 0 : i32, reverse = false}>({
+  ^bb0(%arg0: f32, %arg1: f32):
+    %add = arith.addf %arg0, %arg1 : f32
+    tt.scan.return %add : f32
+  }) : (tensor<8x16xf32, #AL>) -> tensor<8x16xf32, #AL>
+  tt.return
+}
 }
