From 456735bc0461c76ae8145375d890021e72a76c9a Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 26 Feb 2025 15:22:43 -0800 Subject: [PATCH] [Blackwell] Prevent the tmem allocator bitmap from going out of bounds There was a potential memory corruption in the tmem allocator; this makes the code more robust. --- .../Transforms/TensorMemoryAllocation.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp index 101cc096bad5..4e92293fdb06 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp @@ -43,7 +43,7 @@ struct MemoryBitMap { } void alloc(const TMemChunk &chunk) { // Ensure the underlying data fits the allocation. - while ((chunk.startCol + chunk.numCols) * chunk.numRows >= elements.size()) + while ((chunk.startCol + chunk.numCols) * kNumRows >= elements.size()) elements.resize(2 * elements.size(), false); for (int i = 0; i < chunk.numCols; i++) { @@ -92,8 +92,13 @@ struct MemoryBitMap { } private: - bool isUsed(int row, int col) const { return elements[row + col * kNumRows]; } + bool isUsed(int row, int col) const { + if (row + col * kNumRows >= elements.size()) + return false; + return elements[row + col * kNumRows]; + } void setUsed(int row, int col, bool used) { + assert(row + col * kNumRows < elements.size()); elements[row + col * kNumRows] = used; }