Merge OpenAI Triton commit 716a521 #3361

Merged — 9 commits, Feb 6, 2025
8 changes: 7 additions & 1 deletion lib/Dialect/Triton/IR/Ops.cpp
@@ -191,7 +191,13 @@ void StoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
OpFoldResult TransOp::fold(FoldAdaptor adaptor) {
// transpose(x, order=[0, 1, ...]) -> x
if (isIota(getOrder())) {
return getSrc();
// If the source and result types are the same, we can return the source.
// If their layouts differ (even if structurally equivalent), we need to
// insert a convert_layout in between, as otherwise ::fold complains.
// We do this in CanonicalizeConvertFromTranspose.
if (getSrc().getType() == getType()) {
return getSrc();
}
}

// transpose(transpose(x)) -> transpose(x)
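As a minimal illustrative sketch (hypothetical function name and tensor shape, not part of this PR): when the order is the identity permutation and the source and result types match, the fold still fires, so canonicalization (e.g. triton-opt -canonicalize) simply replaces the transpose with its operand.

// Identity-order transpose whose source and result types are identical:
// TransOp::fold returns the operand and the op is erased.
tt.func @fold_identity_trans(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = tt.trans %arg0 {order = array<i32: 0, 1>} : tensor<16x32xf32> -> tensor<16x32xf32>
  tt.return %0 : tensor<16x32xf32>
}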
8 changes: 8 additions & 0 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -86,6 +86,14 @@ struct CanonicalizeConvertFromTranspose
mlir::LogicalResult
matchAndRewrite(triton::TransOp op,
PatternRewriter &rewriter) const override {
// transpose(x, order=[0, 1, ...]) -> x
// We turn it into a (trivial) convert_layout that may be folded away
if (isIota(op.getOrder())) {
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, op.getType(),
op.getSrc());
return success();
}

// If the layouts are structurally the same, the convert is trivial
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert || !isConvertTrivial(convert))
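For the case this pattern targets, a hand-reduced sketch (hypothetical function name; the #mma and #dot_linear attributes are copied from the simplify_trans_trans test added in test/TritonGPU/canonicalize.mlir below, where the trans(trans(x)) fold has already composed the two transposes into exactly this kind of identity-order transpose between structurally equivalent but distinct layout attributes). The pattern rewrites the transpose into a ttg.convert_layout, which the usual convert_layout canonicalizations can then simplify further:

#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}>
#dot_linear = #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [64, 0], [128, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = []}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
  // Identity-order transpose whose operand and result layouts are structurally
  // equivalent but not the same attribute: folding to the operand would change
  // the type, so the pattern emits a ttg.convert_layout instead.
  tt.func public @identity_trans_to_convert(%arg0: tensor<256x256xf32, #dot_linear>) -> tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> {
    %0 = tt.trans %arg0 {order = array<i32: 0, 1>} : tensor<256x256xf32, #dot_linear> -> tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
    tt.return %0 : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
  }
}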
1 change: 0 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -11,7 +11,6 @@ add_triton_library(TritonGPUTransforms
OptimizeThreadLocality.cpp
Pipeliner/AssignLatencies.cpp
Pipeliner/MatmulLoopPipeline.cpp
Pipeliner/OuterLoopPipeline.cpp
Pipeliner/PipelineExpander.cpp
Pipeliner/TestPipelineAssignLatencies.cpp
Pipeliner/TestPipelineScheduleLoop.cpp
131 changes: 0 additions & 131 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp

This file was deleted.

23 changes: 0 additions & 23 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp
@@ -33,19 +33,6 @@ namespace gpu {
#define GEN_PASS_DEF_TRITONGPUPIPELINE
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

static void tryAndPipelineOuterLoop(scf::ForOp forOp) {
mlir::triton::PipeliningOption options;
bool foundSchedule = false;
// Limit 2 stages to not require extra shared memory.
foundSchedule = getOuterLoopSchedule(forOp, /*numStage=*/2, options);
if (!foundSchedule)
return;
IRRewriter rewriter(forOp->getContext());
rewriter.setInsertionPoint(forOp);
FailureOr<scf::ForOp> newForOp =
mlir::triton::pipelineForLoop(rewriter, forOp, options);
}

static scf::ForOp pipelineLoop(scf::ForOp forOp, int numStages) {
mlir::triton::PipeliningOption options;

@@ -92,17 +79,12 @@ struct PipelinePass : public impl::TritonGPUPipelineBase<PipelinePass> {
if (loops.empty())
return;

llvm::SmallSetVector<scf::ForOp, 8> outerLoops;
llvm::SmallVector<scf::ForOp> pipelinedLoops;
for (scf::ForOp forOp : loops) {
auto outerLoop = dyn_cast<scf::ForOp>(forOp->getParentOp());
int loopNumStages = getNumStagesOrDefault(forOp);
scf::ForOp pipelinedFor = pipelineLoop(forOp, loopNumStages);
if (pipelinedFor != nullptr)
pipelinedLoops.push_back(pipelinedFor);
if (pipelinedFor != nullptr && outerLoop &&
getNumStagesOrDefault(outerLoop) > 1)
outerLoops.insert(outerLoop);
}

// There is a hard dependency between load pipelining and the TC05MMA
@@ -122,11 +104,6 @@ struct PipelinePass : public impl::TritonGPUPipelineBase<PipelinePass> {
if (applyPatternsGreedily(getOperation(), std::move(patterns)).failed())
return signalPassFailure();

// Try to pipeline the outer loop to overlap the prologue and epilogue of
// the inner loop.
for (scf::ForOp outerLoop : outerLoops)
tryAndPipelineOuterLoop(outerLoop);

// Re-collect loop ops
loops.clear();
getOperation()->walk([&](scf::ForOp forOp) {
6 changes: 2 additions & 4 deletions python/test/unit/language/test_compile_errors.py
@@ -7,7 +7,7 @@
import triton.language as tl
from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
import traceback
from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300, is_xpu
from triton._internal_testing import is_cuda, is_hip, is_hip_mi300, is_xpu


def format_exception(type, value, tb):
@@ -362,7 +362,7 @@ def kernel():


@pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15])
def test_fp8_support(dtype):
def test_fp8_support(fresh_triton_cache, dtype):
warning_dtypes = []
supported_dtypes = [tl.float8e5]
if is_cuda():
@@ -377,8 +377,6 @@ def test_fp8_support(dtype):
supported_dtypes += [tl.float8e4nv, tl.float8e4b8, tl.float8e5b16]
elif is_xpu():
supported_dtypes += [tl.float8e4b15, tl.float8e4nv]
elif is_interpreter():
supported_dtypes = [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15]

@triton.jit
def dtype_kernel(dtype: tl.constexpr):
69 changes: 69 additions & 0 deletions python/test/unit/language/test_core.py
@@ -5920,6 +5920,75 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
assert torch.equal(z, x)


dot_layouts = [
DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=4),
DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), op_idx=1, k_width=4),
DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), op_idx=0, k_width=2),
DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2),
DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), op_idx=0, k_width=1),
DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=1),
]

shared_layouts = [
SharedLayout(4, 2, 4, [0, 1], [1, 1], [1, 1], [0, 1]),
SharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1]),
SharedLayout(16, 1, 16, [1, 0], [1, 1], [1, 1], [0, 1]),
]


@pytest.mark.parametrize("M, N", [[16, 32]])
@pytest.mark.parametrize("dtype", ['float16', 'float8e5', 'float32'])
@pytest.mark.parametrize("shared_layout", shared_layouts)
@pytest.mark.parametrize("dist_layout", filter_layouts(dot_layouts))
def test_local_load_store_dot(M, N, dtype, dist_layout, shared_layout, device, tmp_path: pathlib.Path):
if dtype == "float32":
mlir_dtype = "f32"
elif dtype == "float16":
mlir_dtype = "f16"
elif dtype == "float8e5":
mlir_dtype = "f8E5M2"

layouts = f"""
#dist = {dist_layout}
#shared = {shared_layout}
#smem = #ttg.shared_memory
"""
ir = layouts + f"""
module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<{mlir_dtype}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{mlir_dtype}> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #dist>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
%1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>>
%2 = tt.splat %arg0 : !tt.ptr<{mlir_dtype}> -> tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist>
%3 = tt.splat %arg1 : !tt.ptr<{mlir_dtype}> -> tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist>
%4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<{M}x1xi32, #dist>
%5 = arith.muli %4, %cst : tensor<{M}x1xi32, #dist>
%6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> -> tensor<1x{N}xi32, #dist>
%7 = tt.broadcast %6 : tensor<1x{N}xi32, #dist> -> tensor<{M}x{N}xi32, #dist>
%8 = tt.broadcast %5 : tensor<{M}x1xi32, #dist> -> tensor<{M}x{N}xi32, #dist>
%9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #dist>
%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist>, tensor<{M}x{N}xi32, #dist>
%11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist>
%12 = ttg.local_alloc %11 : (tensor<{M}x{N}x{mlir_dtype}, #dist>) -> !ttg.memdesc<{M}x{N}x{mlir_dtype}, #shared, #smem>
%13 = ttg.local_load %12 : !ttg.memdesc<{M}x{N}x{mlir_dtype}, #shared, #smem> -> tensor<{M}x{N}x{mlir_dtype}, #dist>
%14 = tt.addptr %3, %9 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist>, tensor<{M}x{N}xi32, #dist>
tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist>
tt.return
}}
}}
"""

x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
z = torch.empty_like(x, device=device)

temp_file = tmp_path / "test_local_load_store_dot.ttgir"
temp_file.write_text(ir)
kernel = triton.compile(str(temp_file))

kernel[(1, 1, 1)](x, z)
assert torch.equal(z, x)


mma_layouts = [
MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), # simple 4 warps case
2 changes: 1 addition & 1 deletion scripts/test-triton.sh
@@ -189,7 +189,7 @@ run_core_tests() {
ensure_spirv_dis

TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=language \
pytest -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu language/ --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
pytest -k "not test_local_load_store_dot" -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu language/ --ignore=language/test_line_info.py --ignore=language/test_subprocess.py

TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=subprocess \
pytest -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu language/test_subprocess.py
2 changes: 0 additions & 2 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -974,8 +974,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
%AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
%BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
// CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK-NOT: nvgpu.ldmatrix
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a>
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b>
16 changes: 16 additions & 0 deletions test/TritonGPU/canonicalize.mlir
@@ -199,3 +199,19 @@ tt.func @infer_trans(%arg0: tensor<32x32xf32, #linear>) -> tensor<32x32xf32, #bl
}

}

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}>
#dot_t = #ttg.linear<{register = [[1, 0], [0, 8], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 64], [0, 128]], lane = [[2, 0], [4, 0], [0, 1], [0, 2], [0, 4]], warp = [[0, 16], [0, 32]], block = []}>
#dot_linear = #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [64, 0], [128, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = []}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @simplify_trans_trans
tt.func public @simplify_trans_trans(%arg0: tensor<256x256xf32, #dot_linear>) -> tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> {
// CHECK-NEXT: ttg.convert_layout
%a = tt.trans %arg0 {order=array<i32: 1,0>} : tensor<256x256xf32, #dot_linear> -> tensor<256x256xf32, #dot_t>
%b = tt.trans %a {order=array<i32: 1,0>} : tensor<256x256xf32, #dot_t> -> tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
tt.return %b : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
}
}
11 changes: 1 addition & 10 deletions test/TritonGPU/loop-pipeline-hopper.mlir
@@ -111,6 +111,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
// CHECK: scf.for
// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc
// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc
// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]]
@@ -121,7 +122,6 @@
// CHECK: ttg.async_copy_global_to_local
// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]]
// CHECK: ttg.async_copy_global_to_local
// CHECK: scf.for
// CHECK-DAG: %[[A0:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]],
// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]],
// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32}
@@ -144,15 +144,6 @@
// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32}
// CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]]
// CHECK: ttg.async_wait {num = 0 : i32}
// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]]
// CHECK: ttg.async_copy_global_to_local
// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]]
// CHECK: ttg.async_copy_global_to_local
// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]]
// CHECK: ttg.async_copy_global_to_local
// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]]
// CHECK: ttg.async_copy_global_to_local
// CHECK scf.yield
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},