
Commit

Merge remote-tracking branch 'origin/main' into mogball/ws5
Mogball committed Feb 28, 2025
2 parents a9d96cb + 37ff43c commit e54a900
Showing 16 changed files with 768 additions and 181 deletions.
@@ -400,7 +400,7 @@ def TTNG_TMEMStoreOp : TTNG_Op<"tmem_store", [MemoryEffects<[MemWrite]>]> {
let hasVerifier = 1;
}

def TTNG_TMEMAllocOp : TTNG_Op<"tmem_alloc", [MemoryEffects<[MemWrite]>]> {
def TTNG_TMEMAllocOp : TTNG_Op<"tmem_alloc", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "allocate tensor memory";
let description = [{
This operation allocates a buffer in tensor memory and returns a descriptor
19 changes: 19 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -388,6 +388,25 @@ LogicalResult TMEMAllocOp::verify() {
return success();
}

// TMEMAllocOp
void TMEMAllocOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
Operation *op = getOperation();
// If the allocation is immutable, report no side effects so that CSE and
// DCE can fold it away in early compiler passes.
// Once the memory offset has been computed, we attach the true side
// effects to the op.
if (!getType().getMutableMemory() && !op->hasAttr("tensor_memory_col_offset"))
return;
effects.emplace_back(MemoryEffects::Allocate::get(),
mlir::triton::nvidia_gpu::TensorMemory::get());
if (getSrc())
effects.emplace_back(MemoryEffects::Write::get(),
getOperation()->getOpResult(0),
mlir::triton::nvidia_gpu::TensorMemory::get());
}

bool isDescendingOrder(triton::gpu::MemDescType type) {
auto order = triton::gpu::getOrder(type);
auto rank = type.getRank();
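As a note on the hunk above: the new getEffects hook makes the reported effects conditional rather than fixed by a trait. A minimal Python restatement of that decision table (an illustration only, not Triton code):

```python
def tmem_alloc_effects(mutable: bool, has_col_offset: bool, has_src: bool):
    """Restates the rule in TMEMAllocOp::getEffects: an immutable allocation
    whose tensor-memory column offset has not been assigned yet reports no
    effects, so CSE/DCE may fold it in early passes; otherwise it reports an
    Allocate on tensor memory, plus a Write when an initial value is given."""
    if not mutable and not has_col_offset:
        return []
    effects = ["Allocate<TensorMemory>"]
    if has_src:
        effects.append("Write<TensorMemory>")
    return effects


print(tmem_alloc_effects(mutable=False, has_col_offset=False, has_src=False))  # []
print(tmem_alloc_effects(mutable=True, has_col_offset=False, has_src=True))
# ['Allocate<TensorMemory>', 'Write<TensorMemory>']
```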
22 changes: 18 additions & 4 deletions lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp
@@ -55,17 +55,22 @@ struct TCGen5MMAScaleSharedToTmemConversion
: public OpRewritePattern<TCGen5MMAScaledOp> {
using OpRewritePattern<TCGen5MMAScaledOp>::OpRewritePattern;

bool lowerScaleToTmem(OpOperand &operand, PatternRewriter &rewriter) const {
// Create a tmem_copy of scales from shared memory to tmem. `rows` is the M or
// N of the MMA operation (for LHS or RHS respectively).
bool lowerScaleToTmem(OpOperand &operand, PatternRewriter &rewriter,
int rows) const {
Location loc = operand.getOwner()->getLoc();
MLIRContext *context = operand.getOwner()->getContext();
Attribute tensorMemorySpace = TensorMemorySpaceAttr::get(context);
auto oldType = cast<MemDescType>(operand.get().getType());
auto numElems = product(oldType.getShape());
Type elType = oldType.getElementType();
SwizzledSharedEncodingAttr oldEncoding =
cast<SwizzledSharedEncodingAttr>(oldType.getEncoding());
CTALayoutAttr CTALayout = getCTALayout(oldEncoding);
ArrayRef<unsigned> CTASplitNum = CTALayout.getCTASplitNum();
ArrayRef<int64_t> shape = oldType.getAllocShape();
// Distribute the scales across the rows of the MMA operation.
SmallVector<int64_t> shape = {rows, numElems / rows};
Attribute scaleEncoding = TensorMemoryScalesEncodingAttr::get(
context, CTASplitNum[0], CTASplitNum[1]);
Type scaleAType =
@@ -84,12 +89,21 @@
MLIRContext *context = op->getContext();
auto aScaleType = op.getAScale().getType();
auto bScaleType = op.getBScale().getType();
int blockM = op.getA()
.getType()
.getShape()[op.getA().getType().getShape().size() - 2];
int blockN = op.getB()
.getType()
.getShape()[op.getB().getType().getShape().size() - 1];
int blockK = op.getA()
.getType()
.getShape()[op.getA().getType().getShape().size() - 1];
bool anyChanged = false;
if (isa<SwizzledSharedEncodingAttr>(aScaleType.getEncoding())) {
anyChanged = lowerScaleToTmem(op.getAScaleMutable(), rewriter);
anyChanged = lowerScaleToTmem(op.getAScaleMutable(), rewriter, blockM);
}
if (isa<SwizzledSharedEncodingAttr>(bScaleType.getEncoding())) {
anyChanged = lowerScaleToTmem(op.getBScaleMutable(), rewriter);
anyChanged = lowerScaleToTmem(op.getBScaleMutable(), rewriter, blockN);
}
return LogicalResult::success(anyChanged);
}
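The hunk above replaces the shared-memory alloc shape with a shape derived from the MMA block dimensions. A small sketch of that arithmetic, with illustrative numbers (the helper below is hypothetical, not part of Triton):

```python
from math import prod


def tmem_scale_shape(shared_alloc_shape, rows):
    """Hypothetical helper mirroring the shape computation in lowerScaleToTmem:
    keep the element count of the shared-memory scale buffer, but view it as
    (rows, numElems / rows), where rows is blockM for the LHS scale and blockN
    for the RHS scale."""
    num_elems = prod(shared_alloc_shape)
    assert num_elems % rows == 0, "scale elements must divide evenly across rows"
    return [rows, num_elems // rows]


# Illustrative numbers only: 512 scale elements distributed over blockM = 128.
print(tmem_scale_shape([2, 256], rows=128))  # [128, 4]
```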
6 changes: 4 additions & 2 deletions python/src/ir.cc
@@ -1,4 +1,4 @@
#include <optional>
#include <optional>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
@@ -1774,7 +1774,9 @@ void init_triton_ir(py::module &&m) {
std::string funcToDump;
if (!haveDump) {
funcToDump = triton::tools::getStrEnv("MLIR_ENABLE_DUMP");
if (!funcToDump.empty())
bool isEnvValueBool =
triton::tools::isEnvValueBool(funcToDump).has_value();
if (!funcToDump.empty() && !isEnvValueBool)
haveDump = true;
}
if (haveDump) {
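The change above keeps MLIR_ENABLE_DUMP=1 (or any other boolean-like value) from being misread as the name of a function to dump. A hedged sketch of the resulting interpretation, assuming a conventional set of boolean spellings (the helper name is hypothetical):

```python
def mlir_dump_filter(value: str):
    """MLIR_ENABLE_DUMP may hold either a boolean-like switch ("1", "true", ...)
    or the name of a single function to dump. Only a non-empty value that is
    not boolean-like is treated as a per-function filter; the exact set of
    boolean spellings below is an assumption for illustration."""
    boolean_like = {"1", "true", "on", "yes", "0", "false", "off", "no"}
    if value and value.strip().lower() not in boolean_like:
        return value  # dump IR only for this function
    return None  # boolean-like or empty: no per-function filter


assert mlir_dump_filter("my_kernel") == "my_kernel"
assert mlir_dump_filter("1") is None
assert mlir_dump_filter("") is None
```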
14 changes: 13 additions & 1 deletion python/test/unit/language/test_tuple.py
@@ -115,6 +115,17 @@ class Tensor(NamedTuple):
stride: tuple


@triton.jit
def _namedtuple_create_func0(shape, ptr, stride):
return Tensor(shape=shape, ptr=ptr, stride=stride)


@triton.jit
def _namedtuple_create_func1(shape, ptr, stride):
tensor = Tensor(shape=shape, ptr=ptr, stride=stride)
return tensor


@triton.jit
def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
offs_m = tl.arange(0, BLOCK_M)
@@ -127,7 +138,8 @@ def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
X = Tensor(shape=_X.shape, ptr=_X.ptr, stride=_X.stride)
X = _namedtuple_create_func0(_X.shape, _X.ptr, _X.stride)
Y = _namedtuple_create_func1(Y.shape, Y.ptr, Y.stride)
Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1]
Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1]
x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0)
62 changes: 40 additions & 22 deletions python/test/unit/runtime/test_cache.py
@@ -1,5 +1,6 @@
import importlib.util
import itertools
import os
import shutil
import pathlib

@@ -553,28 +554,45 @@ def compiled_hook(*args, **kwargs):

@pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip())
def test_within_2gb(device, fresh_triton_cache) -> None:

@triton.jit
def kernel_add(a):
tl.load(a)

# This is the attribute we want to test
pointer_range_32 = None

def cache_hook(*args, **kwargs):
nonlocal pointer_range_32
pointer_range_32 = [k for k, v in kwargs["compile"]["configs"][0].items() if ['tt.pointer_range', 32] in v]

JITFunction.cache_hook = cache_hook
# In warmup we assume that the pointer range is 32 bits
kernel_add.warmup(torch.float32, grid=(1, ))
assert pointer_range_32 == [(0, )]
# Torch tensor > 2GB
kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device))
assert len(pointer_range_32) == 0
# Torch tensor <= 2GB
kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device))
assert pointer_range_32 == [(0, )]
default_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0")
from triton.backends import backends

amd_backend = backends["amd"]
try:
use_buffer_ops_opts = ["1", "0"]
# The ranges should only be available when buffer ops are enabled
pointer_ranges = [[(0, )], []]
for use_buffer_ops, pointer_range in zip(use_buffer_ops_opts, pointer_ranges):
# Set AMDGCN_USE_BUFFER_OPS
amd_backend.compiler.use_buffer_ops.cache_clear()
os.environ["AMDGCN_USE_BUFFER_OPS"] = use_buffer_ops

@triton.jit
def kernel_add(a):
tl.load(a)

# This is the attribute we want to test
pointer_range_32 = None

def cache_hook(*args, **kwargs):
nonlocal pointer_range_32
pointer_range_32 = [
k for k, v in kwargs["compile"]["configs"][0].items() if ["tt.pointer_range", 32] in v
]

JITFunction.cache_hook = cache_hook
# In warmup we assume that the pointer range is 32 bits
kernel_add.warmup(torch.float32, grid=(1, ))
assert pointer_range_32 == pointer_range
# Torch tensor > 2GB
kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device))
assert len(pointer_range_32) == 0
# Torch tensor <= 2GB
kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device))
assert pointer_range_32 == pointer_range
finally:
amd_backend.compiler.use_buffer_ops.cache_clear()
os.environ["AMDGCN_USE_BUFFER_OPS"] = default_buffer_ops


def test_function_arguments(device):
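The rewritten test flips AMDGCN_USE_BUFFER_OPS per iteration and calls use_buffer_ops.cache_clear() each time, which suggests the backend memoizes the environment lookup. A minimal sketch of that pattern with a stand-in flag (hypothetical, not the AMD backend's actual code):

```python
import os
from functools import lru_cache


@lru_cache
def use_buffer_ops() -> bool:
    # Stand-in for the memoized env lookup on the AMD backend's compiler.
    return os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"


os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"
use_buffer_ops.cache_clear()  # without this, a previously cached value wins
assert use_buffer_ops() is True

os.environ["AMDGCN_USE_BUFFER_OPS"] = "0"
use_buffer_ops.cache_clear()
assert use_buffer_ops() is False
```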
34 changes: 22 additions & 12 deletions python/triton/compiler/code_generator.py
@@ -89,6 +89,23 @@ def _check_fn_args(node, fn, args):
)


def _is_namedtuple(val):
return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")


def _apply_to_tuple_values(value, fn):
if _is_namedtuple(type(value)):
fields = value._fields
elif isinstance(value, language.tuple):
fields = value.type.fields
else:
assert False, f"Unsupported type {type(value)}"

vals = [fn(v) for v in value]
types = [v.type for v in vals]
return language.tuple(vals, language.tuple_type(types, fields))


def flatten_values_to_ir(values: Iterable[base_value]):
handles = []
for v in values:
@@ -349,9 +366,6 @@ def _is_constexpr_global(self, name):

return False

def _is_namedtuple(self, val):
return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")

def _define_name_lookup(self):

def local_lookup(name: str, absent):
@@ -370,7 +384,7 @@ def global_lookup(name: str, absent):
getattr(val, "__triton_builtin__", False), #
getattr(val, "__module__", "").startswith("triton.language"), #
isinstance(val, language.dtype), #
self._is_namedtuple(val),
_is_namedtuple(val),
self._is_constexpr_global(name), #
# Allow accesses to globals while visiting an ast.arg
# because you should be able to do
@@ -451,7 +465,7 @@ def visit_Return(self, node):

def decay(value):
if isinstance(value, language.tuple):
return language.tuple([decay(v) for v in value.values])
return _apply_to_tuple_values(value, decay)
elif isinstance(value, (language.constexpr, int, float)):
return semantic.to_tensor(value, self.builder)
return value
@@ -575,13 +589,8 @@ def assignTarget(self, target, value):
def visit_Assign(self, node):
# construct values to assign
def _sanitize_value(value):
if self._is_namedtuple(type(value)):
vals = [_sanitize_value(v) for v in value]
types = [v.type for v in vals]
fields = type(value)._fields
return language.tuple(vals, language.tuple_type(types, fields))
if isinstance(value, language.tuple):
return language.tuple([_sanitize_value(v) for v in value.values])
return _apply_to_tuple_values(value, _sanitize_value)
native_nontensor_types = (language.dtype, language.tuple)
value = _unwrap_if_constexpr(value)
if value is not None and \
@@ -1253,7 +1262,8 @@ def visit_Call(self, node):

if fn in self.builtin_namespace.values():
args = map(_unwrap_if_constexpr, args)
return fn(*args, **kws)
ret = fn(*args, **kws)
return _apply_to_tuple_values(ret, lambda x: x) if _is_namedtuple(type(ret)) else ret

def visit_Constant(self, node):
return constexpr(node.value)
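The new _apply_to_tuple_values helper rebuilds a tuple value after mapping a function over its elements while preserving named fields, so namedtuples survive decay, assignment sanitization, and builtin calls. A plain-Python analogue of that behavior (simplified: it returns the same namedtuple class instead of a language.tuple):

```python
from collections import namedtuple


def _is_namedtuple(cls):
    # Same check as in code_generator.py: a tuple subclass carrying _fields.
    return isinstance(cls, type) and issubclass(cls, tuple) and hasattr(cls, "_fields")


def apply_to_tuple_values(value, fn):
    """Map fn over the elements while keeping the named fields, so the result
    is still addressable as value.shape / value.ptr / value.stride."""
    assert _is_namedtuple(type(value)), f"unsupported type {type(value)}"
    return type(value)(*[fn(v) for v in value])


Tensor = namedtuple("Tensor", ["shape", "ptr", "stride"])
t = Tensor(shape=(4, 4), ptr=0x1000, stride=(4, 1))
print(apply_to_tuple_values(t, str))
# Tensor(shape='(4, 4)', ptr='4096', stride='(4, 1)')
```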