Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIPELINER] Refactor pipeliner lowering. #5989

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
b4255bd
Remove outer loop pipelining transformation
pawelszczerbuk Jan 29, 2025
77d9e32
Merge branch 'main' into pawel/remove_outer_loop_pipe
pawelszczerbuk Jan 30, 2025
a8279ad
Starting to work on lowering loads
pawelszczerbuk Jan 31, 2025
bc5afdf
.
pawelszczerbuk Feb 5, 2025
db1ed00
Merge branch 'main' into pawel/refactor_pipe_lowering
pawelszczerbuk Feb 5, 2025
3528729
Working on lowering loads
pawelszczerbuk Feb 5, 2025
b74e9f4
Merge branch 'main' into pawel/refactor_pipe_lowering
pawelszczerbuk Feb 5, 2025
5f7a660
Merge branch 'main' into pawel/refactor_pipe_lowering
pawelszczerbuk Feb 5, 2025
ad2ece4
Working on createAsyncCopy
pawelszczerbuk Feb 6, 2025
e92f286
.
pawelszczerbuk Feb 10, 2025
0510fbd
Somewhat working version for simple loads, added tests
pawelszczerbuk Feb 11, 2025
6a1a106
.
pawelszczerbuk Feb 11, 2025
5b46c26
Some more tests, some more fixes
pawelszczerbuk Feb 12, 2025
109f1fa
.
pawelszczerbuk Feb 13, 2025
bece86e
Tests and fixes
pawelszczerbuk Feb 13, 2025
bd7dc10
Merge branch 'main' into pawel/refactor_pipe_lowering
pawelszczerbuk Feb 13, 2025
ac3c15b
Putting transformations in separate files, calling them from Software…
pawelszczerbuk Feb 14, 2025
9d84de0
typo
pawelszczerbuk Feb 14, 2025
6b4e26e
Removing LoopScheduling pass
pawelszczerbuk Feb 14, 2025
bf953ea
Adding perf remarks, cleaning up the comments
pawelszczerbuk Feb 14, 2025
5b7dafe
Update comments, remove dead code
pawelszczerbuk Feb 14, 2025
9f5fd47
Removing more dead code. AssignLatencies always serializes latencies …
pawelszczerbuk Feb 14, 2025
1db956c
Merge branch 'main' into pawel/refactor_pipe_lowering
pawelszczerbuk Feb 14, 2025
8c6c855
Adding tests for asymmetric loads and for dependent loads. Fixing bugs
pawelszczerbuk Feb 15, 2025
ae6110c
Allocate additional buffer for wgmma pipelining
pawelszczerbuk Feb 18, 2025
73ab770
Properly handling cases with load users in next iteration and across …
pawelszczerbuk Feb 19, 2025
1603717
Fix for crash in tests, perf of LUT loads confirmed to be on par with…
pawelszczerbuk Feb 19, 2025
4a83f72
Merge branch 'main' into pawel/pawel/refactor_pipeline_lowering2
pawelszczerbuk Feb 19, 2025
5bfcbb0
TMA loads and gather lowering implemented with tests
pawelszczerbuk Feb 20, 2025
f9b86b7
Lowering of TMA descriptors
pawelszczerbuk Feb 20, 2025
0a68dce
Introducing dumps after sub-passes
pawelszczerbuk Feb 21, 2025
f1e1aa3
Add wait 0 after the loop
pawelszczerbuk Feb 21, 2025
0dc449f
Cleaning out attributes after the pipelining
pawelszczerbuk Feb 21, 2025
0bbb0d3
Enabling wgmma pipelining, stab at proper lowering of multibuffers fo…
pawelszczerbuk Feb 21, 2025
0b47299
All the lit tests are passing
pawelszczerbuk Feb 21, 2025
bb91139
Tests for proper lowering of mmav5 scaled
pawelszczerbuk Feb 21, 2025
9a3e1d6
.
pawelszczerbuk Feb 21, 2025
29a6f5e
Removing MatmulLoopPipeline
pawelszczerbuk Feb 21, 2025
bc7978b
Merge branch 'main' into pawel/pawel/refactor_pipeline_lowering2
pawelszczerbuk Feb 22, 2025
84f4964
Adding missing file
pawelszczerbuk Feb 22, 2025
6676a0b
Merge branch 'main' into pawel/pawel/refactor_pipeline_lowering2
pawelszczerbuk Feb 25, 2025
aba01a0
Merge branch 'main' into pawel/pawel/refactor_pipeline_lowering2
pawelszczerbuk Feb 25, 2025
2b61ebe
More aggressive asyncWaitOp combining, removing incorrect assert from…
pawelszczerbuk Feb 25, 2025
e790c33
PR comments
pawelszczerbuk Feb 25, 2025
54d450c
Merge branch 'main' into pawel/pawel/refactor_pipeline_lowering2
pawelszczerbuk Feb 25, 2025
2a69250
PR comments
pawelszczerbuk Feb 25, 2025
945d8c3
PR comments
pawelszczerbuk Feb 25, 2025
86331b6
Change the way pipelining test checks number of stages to be more rel…
pawelszczerbuk Feb 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ dev-install: dev-install-requires dev-install-triton

.PHONY: golden-samples
golden-samples: triton-opt
$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-pipeline -canonicalize | \
$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \
-o test/TritonGPU/samples/simulated-grouped-gemm.mlir
$(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
$(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-pipeline -canonicalize | \
$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="\bmodule" \
-o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir
36 changes: 19 additions & 17 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"3",
"number of pipeline stages">
"number of pipeline stages">,
Option<"dumpIntermediateSteps", "dump-intermediate-steps",
"bool", /*default*/"false",
"Dump intermediate steps">
];
}

Expand All @@ -45,7 +48,7 @@ def TritonGPUTestPipelineAssignLatencies : Pass<"tritongpu-test-pipeline-assign-
let summary = "test assigning latencies to interesting ops ahead of pipelining";

let description = [{
This is a test pass that tests `assignLatencies` method of `TritonGPULoopScheduling`.
This is a test pass that tests `assignLatencies` method of `TritonGPUPipeline`.
}];

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
Expand All @@ -64,7 +67,20 @@ def TritonGPUTestPipelineScheduleLoop : Pass<"tritongpu-test-pipeline-schedule-l
let summary = "test scheduling a loop for software pipelining";

let description = [{
This is a test pass that tests `scheduleLoop` method of `TritonGPULoopScheduling`.
This is a test pass that tests `scheduleLoop` method of `TritonGPUPipeline`.
}];

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithDialect"];
}

def TritonGPUTestPipelineLowerLoop : Pass<"tritongpu-test-pipeline-lower-loop", "mlir::ModuleOp"> {
let summary = "test lowering a loop for software pipelining";

let description = [{
This is a test pass that tests `lowerLoop` method of `TritonGPUPipeline`.
}];

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
Expand Down Expand Up @@ -254,20 +270,6 @@ def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init"
"mlir::triton::TritonDialect"];
}

def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp"> {
let summary = "Generate loop scheduling for SWP";

let description = "This pass sets up stages and clustering for software pipelining.";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"3",
"number of pipeline stages">
];
}

def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::ModuleOp"> {
let summary = "Improve coalescing for async global to local copies";

Expand Down
39 changes: 22 additions & 17 deletions include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <optional>
#include <utility>
#include <vector>
Expand All @@ -14,25 +15,14 @@ static const char *kDisallowAccMultiBufferAttrName =
"tt.disallow_acc_multi_buffer";
static const char *kLoopStageAttrName = "loop.stage";
static const char *kLoopClusterAttrName = "loop.cluster";
static const char *kLatencyAttrName = "tt.latency";

bool loopHasDistGreaterThanOne(scf::ForOp forOp);
bool isOuterLoop(scf::ForOp forOp);

/// Function to mask operations during scheduling.
Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred);

/// Collect ssa dependencies of `op` in `deps`. if `includeArg` is true,
/// continue looking through loop block arguments.
void addDep(Operation *op, DenseSet<Operation *> &deps, bool includeArg = true,
DenseSet<Operation *> *filter = nullptr);

/// Add operations from `forOp` into a pipeline schedule with the given
/// `stage` when filter is true. This will add operation in the original loop
/// order.
void addOps(scf::ForOp forOp, int stage,
std::vector<std::pair<Operation *, unsigned>> &schedule,
std::function<bool(Operation *)> filter);

/// Replace all uses of `oldUse` with `val` and propagate the type if needed.
/// This is useful when we need to change a memory descriptor from immutable to
/// mutable.
Expand All @@ -50,11 +40,26 @@ void visitNestedOperands(Operation *op, function_ref<void(Value)> visitor);
/// of `op`.
SetVector<Value> getNestedOperands(Operation *op);

// Return the minClusterId and maxClusterId for the given ForOp.
std::pair<int, int> getMinMaxCluster(scf::ForOp &forOp);
std::pair<int, int> getStageCluster(Operation *op);
std::optional<std::pair<int, int>> maybeGetStageCluster(Operation *op);
void setStageCluster(Operation *op, int stage, int cluster);
// Return maximum length of the vectorized copy between registers and shared
// memory for the given tensor type and shared encoding.
int getCopyVecBytes(RankedTensorType registerTy,
gpu::SharedEncodingTrait sharedEnc);

// Serialize the latencies of the operations in the loops into the latency
// attribute.
void serializeLatencies(ModuleOp module, DenseMap<Operation *, int> &opLatency);

// Deserialize the latencies of the operations in the loops from the attribute.
DenseMap<Operation *, int> deserializeLatencies(ModuleOp module);

// Given a result of MemDescSubview, or Alloca, create a MemDescSubview with a
// single buffer slice (leading dimension equal to 1), at the given index.
Value createSingleBufferView(OpBuilder &builder, Value alloc, Value idx);
Value createSingleBufferView(OpBuilder &builder, Value alloc, int idx);

// Create an allocation and init the mbarriers.
Value createBarrierAlloc(scf::ForOp forOp, int numBarriers);

} // namespace triton
} // namespace mlir

Expand Down
44 changes: 33 additions & 11 deletions include/triton/Dialect/TritonGPU/Transforms/Schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ namespace gpu {

/// Discover operations that should become async and assign latencies to them
/// based on the numStages value provided by the user.
DenseMap<Operation *, int> assignLatencies(ModuleOp forOp, int numStages);
void assignLatencies(ModuleOp moduleOp, int numStages);

/// Schedule the loop based on the latencies assigned to the operations.
void scheduleLoop(scf::ForOp forOp,
const DenseMap<Operation *, int> &opLatency);
/// Schedule the loops based on the latencies assigned to the operations.
void scheduleLoops(ModuleOp moduleOp);

/// Lower the loops to prepare them for pipeline expansion.
void lowerLoops(ModuleOp moduleOp);

}; // namespace gpu

Expand All @@ -34,11 +36,10 @@ bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages,
bool getOuterLoopSchedule(scf::ForOp &forOp, int numStages,
mlir::triton::PipeliningOption &options);

/// Pipeline the Tensor Core Gen 05 MMA ops in `forOps` with `numStages` stages.
/// This will pre-process the loops, lowering the ops related to TG Gen5 MMA,
/// and then pipeline the loops using expander.
void pipelineTC05MMALoops(ModuleOp module,
const SmallVector<scf::ForOp> &forOps, int numStages,
/// Pipeline the Tensor Core Gen 05 MMA ops in the module with `numStages`
/// stages. This will pre-process the loops, lowering the ops related to TG Gen5
/// MMA, and then pipeline the loops using expander.
void pipelineTC05MMALoops(ModuleOp module, int numStages,
bool disableExpander = false);

/// Pipeline the TMA stores in the loop.
Expand All @@ -64,9 +65,12 @@ class CoarseSchedule {

public:
using iterator = decltype(orderClusters)::iterator;
using const_iterator = decltype(orderClusters)::const_iterator;
ClusterList() = default;
iterator begin() { return orderClusters.begin(); }
const_iterator begin() const { return orderClusters.begin(); }
iterator end() { return orderClusters.end(); }
const_iterator end() const { return orderClusters.end(); }
size_t size() { return orderClusters.size(); }
iterator newAtBack() {
orderClusters.push_back(orderClusters.size());
Expand All @@ -86,16 +90,31 @@ class CoarseSchedule {
}
return ret;
}

// Returns true iff cluster `a` appears strictly earlier than cluster `b`
// in this cluster list's iteration order.
// Implementation: linear scan from begin(); whichever of `a`/`b` is
// encountered first decides the result.
// NOTE(review): if `a == b`, the scan hits `a` first and returns true —
// so this is "not after" rather than strictly "before" for equal
// iterators; confirm callers never compare a cluster against itself.
// Aborts (report_fatal_error) when neither iterator is found in the list.
bool isBefore(iterator a, iterator b) const {
  for (auto it = begin(); it != end(); ++it) {
    if (it == a)
      return true;
    if (it == b)
      return false;
  }
  // Neither iterator matched any cluster in [begin(), end()) — the
  // arguments do not belong to this ClusterList; treat as fatal.
  llvm::report_fatal_error(
      "One or both clusters not found in clusters list!");
}
};

CoarseSchedule() = default;
CoarseSchedule(int numStages) : numStages(numStages) {}
int numStages;
ClusterList clusters;
using Cluster = decltype(clusters)::iterator;

DenseMap<Operation *, std::pair<int, Cluster>> opToStageAndCluster;

void setNumStages(int numStages) { this->numStages = numStages; }
int getNumStages() { return numStages; }

void insert(Operation *op, int stage, Cluster cluster) {
assert(stage < numStages && "Invalid stage");
opToStageAndCluster[op] = {stage, cluster};
}

Expand Down Expand Up @@ -133,9 +152,12 @@ class CoarseSchedule {
// Set <stage, cluster> based on CoarseSchedule.
void serialize(scf::ForOp &forOp);
// Create a CoarseSchedule based on forOp's <stage, cluster>.
void deSerialize(scf::ForOp &forOp);
LogicalResult deSerialize(scf::ForOp &forOp);

LLVM_DUMP_METHOD void dump();

private:
int numStages = 0;
};

// Add dependencies of anchor ops to the coarse schedule. Schedule them to
Expand Down
6 changes: 4 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@ add_triton_library(TritonGPUTransforms
FuseNestedLoops.cpp
CombineTensorSelectAndIf.cpp
DecomposeScaledBlocked.cpp
LoopScheduling.cpp
ReduceDataDuplication.cpp
OptimizeAccumulatorInit.cpp
OptimizeDotOperands.cpp
OptimizeThreadLocality.cpp
Pipeliner/AssignLatencies.cpp
Pipeliner/MatmulLoopPipeline.cpp
Pipeliner/LowerLoops.cpp
Pipeliner/ScheduleLoops.cpp
Pipeliner/WGMMAPipeline.cpp
Pipeliner/PipelineExpander.cpp
Pipeliner/TestPipelineAssignLatencies.cpp
Pipeliner/TestPipelineScheduleLoop.cpp
Pipeliner/TestPipelineLowerLoop.cpp
Pipeliner/SoftwarePipeliner.cpp
Pipeliner/TC05MMAPipeline.cpp
Pipeliner/TMAStoresPipeline.cpp
Expand Down
21 changes: 4 additions & 17 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "triton-pipeline-schedule"
#define DEBUG_TYPE "triton-loop-pipeline"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

Expand Down Expand Up @@ -65,17 +65,6 @@ bool isSmallLoad(tt::LoadOp loadOp,
return width < 32;
}

// Returns the size in bytes of the widest contiguous (vectorized) copy
// between registers and shared memory for a tensor of type `registerTy`
// stored with shared encoding `sharedEnc`.
int getCopyVecBytes(RankedTensorType registerTy,
                    ttg::SharedEncodingTrait sharedEnc) {
  // Linear layout of the tensor as held in registers.
  auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(),
                                               registerTy.getEncoding());
  // Linear layout of the same shape under the shared-memory encoding.
  auto sharedLayout =
      triton::gpu::toLinearLayout(registerTy.getShape(), sharedEnc);
  // Composite register->shared mapping; its number of consecutive
  // in/out elements is the vector width (in elements) of one copy.
  auto regToSharedLayout = regLayout.invertAndCompose(sharedLayout);
  const int vecElems = regToSharedLayout.getNumConsecutiveInOut();
  // Convert element count to bytes (bit width / 8).
  return vecElems * registerTy.getElementTypeBitWidth() / 8;
}

bool isPipeliningBeneficial(Operation *op, Operation *finalUser,
tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
Expand Down Expand Up @@ -233,8 +222,7 @@ void assignUserProvidedLatencies(scf::ForOp forOp,
// on the requested number of stages assign the latencies in a way that
// cover all the stages with the sum of latencies in the chain from the first
// load to the final dot op.
DenseMap<Operation *, int> assignLatencies(ModuleOp moduleOp,
int defaultNumStages) {
void assignLatencies(ModuleOp moduleOp, int defaultNumStages) {
auto getNumStagesOrDefault = [defaultNumStages](scf::ForOp forOp) -> int {
// Use the attribute attached to the loop if it exists otherwise use the
// global control.
Expand All @@ -252,7 +240,7 @@ DenseMap<Operation *, int> assignLatencies(ModuleOp moduleOp,
loops.push_back(forOp);
});
if (loops.empty())
return DenseMap<Operation *, int>();
return;

DenseMap<Operation *, int> opLatency;
for (auto forOp : loops) {
Expand Down Expand Up @@ -291,9 +279,8 @@ DenseMap<Operation *, int> assignLatencies(ModuleOp moduleOp,
opLatency[loadOp] = loadLatency;
}
}
return opLatency;
serializeLatencies(moduleOp, opLatency);
}

} // namespace gpu
} // namespace triton
} // namespace mlir
Loading
Loading