Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
add a unit test

linter

update

update

update

update

update

clang-format

rebase

update

update

update

update

update

update

update

update

update
  • Loading branch information
sfzhu93 committed Feb 25, 2025
1 parent 924468e commit 30ed20c
Show file tree
Hide file tree
Showing 5 changed files with 444 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#ifndef TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINE_ERROR_REPORTER_H_
#define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINE_ERROR_REPORTER_H_
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Operation.h"
#include <cstdint>
#include <optional>

using namespace mlir;
using namespace mlir::scf;

/// This class is used to report the scheduling error. It is used by
/// the pipeline expander.
class PipelineErrorReporter {
protected:
ForOp forOp;

unsigned numStages = 0;
const DenseMap<Operation *, unsigned int> &stages;

/// Collect the root defining ops in IfOps. There could be multiple root
/// defining ops in IfOps, as there are then branches and else branches.
DenseSet<Operation *> rootDefiningOps;

/// Recursively find the root defining op of the value in IfOps by traversing
/// back the index of an OpResult and yielded values.
void findRootDefiningOp(Operation *op, unsigned int resultNumber);

DenseSet<Operation *> rootUserOps;

void findRootUserOp(Operation *op, const DenseSet<Operation *> &userOps);

std::optional<unsigned int> findStage(Operation *op);

/// Get the operand from the yield operation of the loop, which is the real
/// value of the loop-carried dependency.
std::optional<Value> getBlockArgYieldValueFromForLoop(BlockArgument arg);

/// Find the loop-carried dependency that really causes the scheduling error,
/// going into nested operations of IfOps.
void findRootSchedulingErrorLoopCarryDep(Operation *consumer,
Operation *producer, Value operand);

public:
explicit PipelineErrorReporter(
ForOp forOp, unsigned numStages,
const DenseMap<Operation *, unsigned int> &stages)
: forOp(forOp), numStages(numStages), stages(stages) {}
PipelineErrorReporter(const PipelineErrorReporter &) = delete;
PipelineErrorReporter &operator=(const PipelineErrorReporter &) = delete;
PipelineErrorReporter(PipelineErrorReporter &&) = delete;
PipelineErrorReporter &operator=(PipelineErrorReporter &&) = delete;

/// Print the scheduling error message using MLIR's diagnostic engine.
/// Depending on whether it is a loop-carried dependency, we print different
/// messages. When distance is 0, it means the consumer and producer are in
/// the same iteration. We are not supposed to have scheduling error in this
/// case, as we have addressed the potential data dependency conflicts.
///
/// When distance is 1, we find the root scheduling error, and print the
/// diagnostic message.
void printSchedulingError(int64_t distance, Operation *consumer,
Operation *producer, Value operand);
};
#endif // TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINE_ERROR_REPORTER_H_
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_triton_library(TritonGPUTransforms
Pipeliner/MatmulLoopPipeline.cpp
Pipeliner/OuterLoopPipeline.cpp
Pipeliner/PipelineExpander.cpp
Pipeliner/PipelineErrorReporter.cpp
Pipeliner/TestPipelineAssignLatencies.cpp
Pipeliner/TestPipelineScheduleLoop.cpp
Pipeliner/SoftwarePipeliner.cpp
Expand Down
223 changes: 223 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineErrorReporter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
// Reporting error messages for scheduling errors.

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include <cstdint>

#include "triton/Dialect/TritonGPU/Transforms/PipelineErrorReporter.h"

#define DEBUG_TYPE "triton-pipeline-error-reporter"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;
using namespace mlir::scf;

void PipelineErrorReporter::findRootDefiningOp(Operation *op,
unsigned int resultNumber) {
LDBG("findRootDefiningOp: " << *op << "\n from its " << resultNumber
<< "th result\n");

if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
// then branch.
{
auto operandFromThenBranch = ifOp.thenYield()->getOperand(resultNumber);
LDBG("operandFromThenBranch: " << operandFromThenBranch);
if (auto opResult = dyn_cast<OpResult>(operandFromThenBranch)) {
findRootDefiningOp(operandFromThenBranch.getDefiningOp(),
opResult.getResultNumber());
} else if (!dyn_cast<BlockArgument>(operandFromThenBranch)) {
rootDefiningOps.insert(operandFromThenBranch.getDefiningOp());
}
}
// else branch.
{
auto operandFromElseBranch = ifOp.elseYield()->getOperand(resultNumber);
LDBG("operandFromElseBranch: " << operandFromElseBranch);
if (auto opResult = dyn_cast<OpResult>(operandFromElseBranch)) {
findRootDefiningOp(opResult.getDefiningOp(),
opResult.getResultNumber());
} else if (!dyn_cast<BlockArgument>(operandFromElseBranch)) {
rootDefiningOps.insert(operandFromElseBranch.getDefiningOp());
}
}
} else {
rootDefiningOps.insert(op);
}
}

std::optional<Value>
PipelineErrorReporter::getBlockArgYieldValueFromForLoop(BlockArgument arg) {
if (arg.getOwner() != forOp.getBody())
return std::nullopt;
// Ignore induction variable.
if (arg.getArgNumber() == 0)
return std::nullopt;
return forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
}

DenseSet<Operation *> findUsersInBlockHierarchy(BlockArgument arg,
Operation *consumerOp) {

DenseSet<Operation *> usersInBlockHierarchy;

for (Operation *user : arg.getUsers()) {
Operation *currentOp = user;
while (currentOp) {
if (currentOp == consumerOp) {
usersInBlockHierarchy.insert(user);
break;
}
currentOp = currentOp->getParentOp();
}
}

return usersInBlockHierarchy;
}

void PipelineErrorReporter::findRootSchedulingErrorLoopCarryDep(
Operation *consumer, Operation *producer, Value operand) {
LDBG("findRootSchedulingErrorLoopCarryDep: this operand is not ready at "
"the consumer: "
<< operand << "\n");

if (auto arg = dyn_cast<BlockArgument>(operand)) {

// This is a loop-carried dependency. Find which value yields the arg.
auto yieldValue = getBlockArgYieldValueFromForLoop(arg);
if (!yieldValue) {
LDBG("no yield value for arg " << arg << " -> BAIL");
return;
}

// first find the root consumer.
rootUserOps = std::move(findUsersInBlockHierarchy(arg, consumer));

assert(producer == yieldValue->getDefiningOp() &&
"producer should be the def of the yield value of operand");
// We repeat the process of computing the producer, because we need to
// know the result number of the producer, which is only available in the
// yield value.
LDBG("yield value (loop-carry): " << *yieldValue << "\n");
if (auto opResult = dyn_cast<OpResult>(*yieldValue)) {
findRootDefiningOp(producer, opResult.getResultNumber());
} else
rootDefiningOps.insert(producer);
}
}

std::optional<unsigned int> PipelineErrorReporter::findStage(Operation *op) {
auto it = stages.find(op);
if (it != stages.end()) {
return it->second;
}
return std::nullopt;
}

void printImplicitUse(Operation *op, InFlightDiagnostic &mainError) {
auto parentOpLoc = op->getParentOp()->getLoc();
if (isa<IfOp>(op->getParentOp())) {
// When an if branch yields a value, the original value is used implicitly
// when the condition is false. In this case, we don't have the source
// location of the implicit use. We can only attach a note to the if
// operation.
// TODO: we can report the location of the yield value in the if branch.
mainError.attachNote(parentOpLoc)
<< "Value is implicitly used here when the condition is false in "
"TTIR, because the variable is updated when the condition is "
"true.";
} else {
mainError.attachNote(parentOpLoc) << "Value is implicitly used here. ";
}
}

// Comparator for sorting FileLineColLoc operations
bool compareFileLineColLoc(Operation *a, Operation *b) {
auto locA = a->getLoc();
auto locB = b->getLoc();
if (!isa<FileLineColLoc>(locA))
return false;
if (!isa<FileLineColLoc>(locB))
return true;
auto fileLineLocA = dyn_cast<FileLineColLoc>(a->getLoc());
auto fileLineLocB = dyn_cast<FileLineColLoc>(b->getLoc());
if (!fileLineLocA || !fileLineLocB)
return false; // Should not happen if used correctly
if (fileLineLocA.getLine() != fileLineLocB.getLine())
return fileLineLocA.getLine() < fileLineLocB.getLine();
return fileLineLocA.getColumn() < fileLineLocB.getColumn();
}

void PipelineErrorReporter::printSchedulingError(int64_t distance,
Operation *consumer,
Operation *producer,
Value operand) {
LDBG("printSchedulingError: distance: " << distance << "\n");
const char *errorMessage =
"The software pipeliner failed due to a dependency conflict, resulting "
"in suboptimal loop performance.";
const char *likelyBuggyMessage = "This is likely to be a bug. Please "
"report it.";
// We only find the root defining ops for loop-carried dependencies.
// When distance is 0, we let the set of root defining ops to be empty.
if (distance > 0) {
findRootSchedulingErrorLoopCarryDep(consumer, producer, operand);
}
if (rootDefiningOps.empty()) {
// We failed to find the root defining ops. Whether the disntance is 0 or
// not, an empty set means we have some bugs in the pipeline expander. We
// should let the user help report the bug.
consumer->emitError() << errorMessage << " " << likelyBuggyMessage;
return;
}
// find the stage of the consumer and the producer.
auto consumerStage = findStage(consumer);
auto producerStage = findStage(producer);
if (!consumerStage || !producerStage) {
// We failed to find the stage of the consumer or the producer. This is
// likely to be a bug. We should let the user help report the bug.
consumer->emitError() << errorMessage << " " << likelyBuggyMessage;
return;
}
InFlightDiagnostic mainError = forOp->emitError() << errorMessage;
mainError.attachNote()
<< "The loop body is divided into " << numStages
<< " stages to optimize GPU I/O and computation resources. Different "
"parts of the loop body are computed in different iterations. "
<< "In iteration i, the update of the following variable is rescheduled "
"to execute in iteration i + "
<< *producerStage - *consumerStage
<< ". However, it must be updated before its use in iteration i.";
auto firstRootDefiningOp = *rootDefiningOps.begin();
auto firstRootDefiningOpName = firstRootDefiningOp->getName().getStringRef();

for (auto op : rootDefiningOps) {
mainError.attachNote(op->getLoc())
<< "The variable is updated here:";
}

auto firstRootUserOp = *rootUserOps.begin();
auto firstRootUserOpName = firstRootUserOp->getName().getStringRef();

// Sort rootUserOps using the custom comparator
std::vector<Operation *> sortedRootUserOps(rootUserOps.begin(),
rootUserOps.end());
std::sort(sortedRootUserOps.begin(), sortedRootUserOps.end(),
compareFileLineColLoc);
// Print sorted operations
for (auto op : sortedRootUserOps) {
auto loc = op->getLoc();
if (isa<UnknownLoc>(loc)) {
printImplicitUse(op, mainError);
} else {
// TODO: we can't find the variable name in the source code. Once we use
// FileLineColRange instead of FileLineColLoc, we can find the variable
// name and provide more detailed debug information.
mainError.attachNote(op->getLoc())
<< "The variable is used here:";
}
}
// TODO: we can also add more detailed debug information for more advanced
// users.
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"

#include "triton/Dialect/TritonGPU/Transforms/PipelineErrorReporter.h"
#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h"

// FIXME: PipelineExpander should not depend on Triton-specific headers!
Expand Down Expand Up @@ -258,7 +259,9 @@ bool LoopPipelinerInternal::verifySchedule() {
continue;
int64_t producerCycle = it->second;
if (consumerCycle < producerCycle - numCylesPerIter * distance) {
consumer->emitError("operation scheduled before its operands");
PipelineErrorReporter errorReporter(forOp, maxStage + 1, stages);
errorReporter.printSchedulingError(distance, consumer, producer,
operand);
return false;
}
}
Expand Down
Loading

0 comments on commit 30ed20c

Please sign in to comment.