Commit: update
add a unit test

linter

update

update

update

update

update

clang-format

rebase
sfzhu93 committed Jan 30, 2025
1 parent 89c0b0a commit 23fa565
Showing 5 changed files with 325 additions and 8 deletions.
51 changes: 51 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/PipelineErrorReporter.h
@@ -0,0 +1,51 @@
#ifndef TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINE_ERROR_REPORTER_H_
#define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINE_ERROR_REPORTER_H_
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Operation.h"
#include <cstdint>
#include <optional>

using namespace mlir;
using namespace mlir::scf;

/// Reports scheduling errors encountered by the pipeline expander, attaching
/// diagnostics to the operations that cause them.
class PipelineErrorReporter {
protected:
  ForOp forOp;
  /// The root defining ops collected from IfOps. An IfOp can contribute
  /// multiple root defining ops, one from its then branch and one from its
  /// else branch.
  DenseSet<Operation *> rootDefiningOps;

  /// Recursively find the root defining op of a value inside IfOps by tracing
  /// the result index of an OpResult back through yielded values.
  void findRootDefiningOp(Operation *op, unsigned int resultNumber);

  /// Get the operand of the loop's yield operation that corresponds to the
  /// given block argument, i.e. the actual value of the loop-carried
  /// dependency.
  std::optional<Value> getBlockArgYieldValueFromForLoop(BlockArgument arg);

  /// Find the loop-carried dependency that actually causes the scheduling
  /// error, descending into the nested regions of IfOps.
  void findRootSchedulingErrorLoopCarryDep(Operation *consumer,
                                           Operation *producer, Value operand);

public:
  explicit PipelineErrorReporter(ForOp forOp) : forOp(forOp) {}
  PipelineErrorReporter(const PipelineErrorReporter &) = delete;
  PipelineErrorReporter &operator=(const PipelineErrorReporter &) = delete;
  PipelineErrorReporter(PipelineErrorReporter &&) = delete;
  PipelineErrorReporter &operator=(PipelineErrorReporter &&) = delete;

  /// Print the scheduling error message using MLIR's diagnostic engine.
  /// The message depends on whether the dependency is loop-carried. A
  /// distance of 0 means the consumer and producer are in the same iteration;
  /// we should never see a scheduling error in that case, since potential
  /// data-dependency conflicts within an iteration have already been
  /// addressed.
  ///
  /// When the distance is 1, we locate the root cause of the scheduling
  /// error and print a diagnostic pointing at it.
  void printSchedulingError(int64_t distance, Operation *consumer,
                            Operation *producer, Value operand);
};
#endif // TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINE_ERROR_REPORTER_H_
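
The intended call site, as a minimal sketch (it mirrors the verifySchedule change to PipelineExpander.cpp further down; the surrounding variables are assumed to exist there):

  // On a schedule violation, build a reporter for the loop and let it emit
  // the diagnostics instead of a bare emitError.
  if (consumerCycle < producerCycle - numCylesPerIter * distance) {
    PipelineErrorReporter errorReporter(forOp);
    errorReporter.printSchedulingError(distance, consumer, producer, operand);
    return false;
  }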
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_triton_library(TritonGPUTransforms
   Pipeliner/MatmulLoopPipeline.cpp
   Pipeliner/OuterLoopPipeline.cpp
   Pipeliner/PipelineExpander.cpp
+  Pipeliner/PipelineErrorReporter.cpp
   Pipeliner/TestPipelineAssignLatencies.cpp
   Pipeliner/TestPipelineScheduleLoop.cpp
   Pipeliner/SoftwarePipeliner.cpp
114 changes: 114 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineErrorReporter.cpp
@@ -0,0 +1,114 @@
// Diagnostic reporting for pipeline scheduling errors.

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include <cstdint>

#include "triton/Dialect/TritonGPU/Transforms/PipelineErrorReporter.h"

#define DEBUG_TYPE "triton-pipeline-error-reporter"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;
using namespace mlir::scf;

void PipelineErrorReporter::findRootDefiningOp(Operation *op,
                                               unsigned int resultNumber) {
  LDBG("findRootDefiningOp: " << *op << "\n from its " << resultNumber
                              << "th result\n");

  if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
    // then branch.
    {
      auto operandFromThenBranch = ifOp.thenYield()->getOperand(resultNumber);
      if (auto opResult = dyn_cast<OpResult>(operandFromThenBranch)) {
        findRootDefiningOp(operandFromThenBranch.getDefiningOp(),
                           opResult.getResultNumber());
      } else if (!dyn_cast<BlockArgument>(operandFromThenBranch)) {
        rootDefiningOps.insert(operandFromThenBranch.getDefiningOp());
      }
    }
    // else branch. An scf.if that yields results always has an else region,
    // so elseYield() is safe here.
    {
      auto operandFromElseBranch = ifOp.elseYield()->getOperand(resultNumber);
      if (auto opResult = dyn_cast<OpResult>(operandFromElseBranch)) {
        findRootDefiningOp(opResult.getDefiningOp(),
                           opResult.getResultNumber());
      } else if (!dyn_cast<BlockArgument>(operandFromElseBranch)) {
        rootDefiningOps.insert(operandFromElseBranch.getDefiningOp());
      }
    }
  } else {
    rootDefiningOps.insert(op);
  }
}
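
// As an illustration of the traversal (an assumed IR shape, not taken from
// the test below): given
//
//   %r = scf.for ... iter_args(%acc = %init) -> (f32) {
//     %v = scf.if %cond -> (f32) {
//       %a = arith.addf %x, %y : f32
//       scf.yield %a : f32
//     } else {
//       scf.yield %acc : f32
//     }
//     scf.yield %v : f32
//   }
//
// findRootDefiningOp(ifOp, 0) traces result #0 through both yields: the then
// branch records the arith.addf as a root defining op, while the else branch
// yields a block argument and is skipped.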

std::optional<Value>
PipelineErrorReporter::getBlockArgYieldValueFromForLoop(BlockArgument arg) {
  if (arg.getOwner() != forOp.getBody())
    return std::nullopt;
  // Ignore the induction variable.
  if (arg.getArgNumber() == 0)
    return std::nullopt;
  // Block argument 0 is the induction variable, so iter_arg i is block
  // argument i + 1 and corresponds to operand i of the loop's terminator.
  return forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
}
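
// As a concrete illustration of that mapping (assumed IR): in
//
//   scf.for %iv = %lb to %ub step %s iter_args(%x = %x0) -> (f32) {
//     ...
//     scf.yield %next : f32
//   }
//
// the body's block arguments are [%iv, %x]; %x has argument number 1 and is
// yielded by terminator operand 0, i.e. %next.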

void PipelineErrorReporter::findRootSchedulingErrorLoopCarryDep(
    Operation *consumer, Operation *producer, Value operand) {
  LDBG("findRootSchedulingErrorLoopCarryDep: this operand is not ready at "
       "the consumer: "
       << operand << "\n");
  if (auto arg = dyn_cast<BlockArgument>(operand)) {
    LDBG("operand is a block arg. Arg number: " << arg.getArgNumber() << "\n");
    // This is a loop-carried dependency. Find which value yields the arg.
    auto yieldValue = getBlockArgYieldValueFromForLoop(arg);
    if (!yieldValue) {
      LDBG("no yield value for arg " << arg << " -> BAIL");
      return;
    }

    assert(producer == yieldValue->getDefiningOp() &&
           "producer should be the def of the yield value of operand");
    // We trace the yield value again, rather than using the producer
    // directly, because we need the producer's result number, which is only
    // available on the yielded OpResult.
    LDBG("yield value (loop-carry): " << *yieldValue << "\n");
    if (auto opResult = dyn_cast<OpResult>(*yieldValue)) {
      findRootDefiningOp(producer, opResult.getResultNumber());
    } else {
      rootDefiningOps.insert(producer);
    }
  }
}

void PipelineErrorReporter::printSchedulingError(int64_t distance,
                                                 Operation *consumer,
                                                 Operation *producer,
                                                 Value operand) {
  std::string errorMessage = "operation scheduled before its operands.";
  std::string likelyBuggyMessage = "This is likely to be a bug. Please "
                                   "report it.";
  // We only look for root defining ops for loop-carried dependencies. When
  // the distance is 0, we leave the set of root defining ops empty.
  if (distance > 0) {
    // TODO: we only find the root defining ops of the producer. We should
    // also find the root user ops of the consumer.
    findRootSchedulingErrorLoopCarryDep(consumer, producer, operand);
  }
  if (rootDefiningOps.empty()) {
    // We failed to find the root defining ops. Whether the distance is 0 or
    // not, an empty set means there is a bug in the pipeline expander
    // itself, so we ask the user to report it.
    consumer->emitError() << errorMessage << " " << likelyBuggyMessage;
  } else {
    consumer->emitError() << errorMessage;
    for (auto op : rootDefiningOps) {
      op->emitError() << "This line likely causes scheduling conflict. "
                         "Consider moving it "
                         "to an earlier position within the loop body.";
    }
  }
}
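
For the loop-carried case, the emitted diagnostics read roughly as follows (file and line locations are illustrative; the two messages are the ones the unit test below checks for):

  kernel.py:42: error: operation scheduled before its operands.
  kernel.py:57: error: This line likely causes scheduling conflict. Consider moving it to an earlier position within the loop body.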
20 changes: 12 additions & 8 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
@@ -33,6 +33,8 @@

 #include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h"

+#include "triton/Dialect/TritonGPU/Transforms/PipelineErrorReporter.h"
+
 #define DEBUG_TYPE "triton-loop-pipelining"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
@@ -248,7 +250,9 @@ bool LoopPipelinerInternal::verifySchedule() {
         continue;
       int64_t producerCycle = it->second;
       if (consumerCycle < producerCycle - numCylesPerIter * distance) {
-        consumer->emitError("operation scheduled before its operands");
+        PipelineErrorReporter errorReporter(forOp);
+        errorReporter.printSchedulingError(distance, consumer, producer,
+                                           operand);
         return false;
       }
     }
@@ -407,8 +411,8 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
   // Keep track of the kernel argument associated to each version of the
   // values passed to the kernel.
   llvm::SmallVector<Value> newLoopArg;
-  // For existing loop argument initialize them with the right version from the
-  // prologue.
+  // For existing loop arguments, initialize them with the right version from
+  // the prologue.
   for (const auto &retVal :
        llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
     Operation *def = retVal.value().getDefiningOp();
@@ -439,8 +443,8 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
   }

   // Create the new kernel loop. When we peel the epilogue we need to peel
-  // `numStages - 1` iterations. Then we adjust the upper bound to remove those
-  // iterations.
+  // `numStages - 1` iterations. Then we adjust the upper bound to remove
+  // those iterations.
   Value newUb = forOp.getUpperBound();
   if (peelEpilogue) {
     Type t = ub.getType();
@@ -750,9 +754,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
     }
   }
   if (dynamicLoop) {
-    // Select return values from this stage (live outs) based on predication.
-    // If the stage is valid select the peeled value, else use previous stage
-    // value.
+    // Select return values from this stage (live outs) based on
+    // predication. If the stage is valid, select the peeled value, else use
+    // the previous stage's value.
     for (auto pair : llvm::enumerate(returnValues)) {
       unsigned ri = pair.index();
       auto [mapVal, currentVersion] = returnMap[ri];
147 changes: 147 additions & 0 deletions python/test/unit/test_perf_warning.py
@@ -165,3 +165,150 @@ def kernel_pipe_error(in_ptr, out_ptr):
    i = torch.empty(64 * 64, dtype=torch.float32).cuda()
    o = torch.empty(64 * 64, dtype=torch.float32).cuda()
    kernel_pipe_error[(1, )](i, o)


def test_remark_swp_op_before_operands_persistent_matmul(capfd, fresh_triton_cache):

    # this example is from https://github.com/triton-lang/triton/issues/5172
    @triton.jit
    def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, #
                                 M, N, K, #
                                 stride_am, stride_ak, #
                                 stride_bk, stride_bn, #
                                 stride_cm, stride_cn, #
                                 BLOCK_SIZE_M: tl.constexpr, #
                                 BLOCK_SIZE_N: tl.constexpr, #
                                 BLOCK_SIZE_K: tl.constexpr, #
                                 GROUP_SIZE_M: tl.constexpr, #
                                 NUM_SMS: tl.constexpr, #
                                 ):
        start_pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
        num_tiles = num_pid_m * num_pid_n

        tiles_per_SM = num_tiles // NUM_SMS
        if start_pid < num_tiles % NUM_SMS:
            tiles_per_SM += 1

        # tile_id = start_pid - NUM_SMS
        tile_id = start_pid

        ki = -1

        offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)

        num_pid_in_group = GROUP_SIZE_M * num_pid_n

        pid_m = 0
        pid_n = 0
        offs_am = tl.arange(0, BLOCK_SIZE_M)
        offs_bn = tl.arange(0, BLOCK_SIZE_N)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        for _ in range(0, k_tiles * tiles_per_SM):
            ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
            if ki == 0:
                # tile_id += NUM_SMS
                group_id = tile_id // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id % group_size_m)
                pid_n = (tile_id % num_pid_in_group) // group_size_m

                start_m = pid_m * BLOCK_SIZE_M
                start_n = pid_n * BLOCK_SIZE_N
                offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
                offs_am = tl.where(offs_am < M, offs_am, 0)
                offs_bn = tl.where(offs_bn < N, offs_bn, 0)
                offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
                offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
            offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
            b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

            a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0)
            accumulator = tl.dot(a, b, accumulator)

            if ki == k_tiles - 1:
                offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :])
                c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
                if c_ptr.dtype.element_ty == tl.float8e4nv:
                    c = accumulator.to(tl.float8e4nv)
                else:
                    c = accumulator.to(tl.float16)
                tl.store(c_ptrs, c, mask=c_mask)
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            tile_id += NUM_SMS  # this line is newly added

    with enable_remark_context():
        dtype = torch.float16

        M = 8192
        N = 8192
        K = 512
        a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
        b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)

        b = b.T.contiguous().T

        # Check constraints.
        assert a.shape[1] == b.shape[0], "Incompatible dimensions"
        assert a.dtype == b.dtype, "Incompatible dtypes"

        # equal to 132 on H100
        NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

        M, K = a.shape
        K, N = b.shape
        dtype = a.dtype
        # Allocates output.
        c = torch.empty((M, N), device=a.device, dtype=dtype)
        # 1D launch kernel where each block gets its own program.
        grid = lambda META: (min(NUM_SMS,
                                 triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
        matmul_kernel_persistent[grid](
            a, b, c, #
            M, N, K, #
            a.stride(0), a.stride(1), #
            b.stride(0), b.stride(1), #
            c.stride(0), c.stride(1), #
            BLOCK_SIZE_M=128, #
            BLOCK_SIZE_N=256, #
            BLOCK_SIZE_K=64, #
            GROUP_SIZE_M=8, #
            NUM_SMS=NUM_SMS, #
            num_stages=3, #
            num_warps=8, #
        )
        # print(c)

    _, err = capfd.readouterr()
    # Split the output into lines for easier processing.
    lines = err.splitlines()
    # The expected diagnostics, in the order they must appear.
    expected_strings = [
        "error: operation scheduled before its operands.",
        "if ki == 0:",
        "error: This line likely causes scheduling conflict. Consider moving it",
        "tile_id += NUM_SMS",
    ]
    # Walk the captured stderr once, advancing through expected_strings as
    # each one is found in order.
    index = 0
    for line in lines:
        if expected_strings[index] in line:
            index += 1
            if index == len(expected_strings):
                break
    if index != len(expected_strings):
        missing_string = expected_strings[index]
        raise AssertionError(f"Missing expected string: '{missing_string}' from {err}")
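
# To run just this test in isolation (assumes a CUDA-capable GPU; the -k
# pattern is simply pytest's substring filter):
#   pytest python/test/unit/test_perf_warning.py -k swp_op_before_operands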
