Skip to content

Commit

Permalink
[Alloc] Enhanced SharedMem Allocation for mutually exclusive but alia…
Browse files Browse the repository at this point in the history
…sed buffers (#337)

* [Alloc] Enhanced for mutually exclusive but aliased buffers

- Use disjoint alias analysis to minimize shared memory requirements

* * fix for allocation test

* * added test
* fixed mfma_enc printer

* * fixed test
  • Loading branch information
sjw36 authored Sep 26, 2023
1 parent 7af5e42 commit 4db99e0
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 25 deletions.
11 changes: 11 additions & 0 deletions include/triton/Analysis/Allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
template <typename T> class Interval {
public:
Interval() {}
Interval(T S) : Start(S), End(S+1) {}
Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); }
T start() const { return Start; }
T end() const { return End; }
Expand All @@ -44,6 +45,16 @@ template <typename T> class Interval {
bool operator<(const Interval &R) const {
return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
}
bool adjacent(T Addr) const {
return Addr+1 == Start || Addr == End;
}
bool adjacent(const Interval &R) const {
return adjacent(R.Start) || adjacent(R.End-1);
}

Interval merge(const Interval &R) const {
return Interval(std::min(Start, R.Start), std::max(End, R.End));
}

private:
T Start = std::numeric_limits<T>::min();
Expand Down
7 changes: 7 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def TritonGPU_Dialect : Dialect {
}
return threadsPerWarp.cast<IntegerAttr>().getInt();
}
static int getSharedSize(ModuleOp mod) {
Attribute sharedAttr = mod->getDiscardableAttr("triton_gpu.shared");
if(!sharedAttr) {
return 0;
}
return sharedAttr.cast<IntegerAttr>().getInt();
}

}];

Expand Down
135 changes: 111 additions & 24 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,68 @@ class AllocationAnalysis {
using BufferT = Allocation::BufferT;

/// Value -> Liveness Range
using IntervalT = Interval<size_t>;
/// Use MapVector to ensure determinism.
using BufferRangeMapT = llvm::MapVector<BufferT *, Interval<size_t>>;
using BufferRangeMapT = llvm::MapVector<BufferT *, IntervalT>;
/// Nodes -> Nodes
using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;

/// Set of Liveness Intervals
class LivenessR : public SmallVector<IntervalT, 4> {
public:
LivenessR() = default;
LivenessR(const LivenessR &) = default;

/// Disjointness
bool isDisjoint() const {
if (size() < 2)
return false;
// sorted so the first OOB proves disjoint
auto maxId = (*this)[0].end();
for (auto rng : *this) {
if (rng.start() <= maxId) {
// adjoining
maxId = std::max(maxId, rng.end());
} else
return true;
}
return false;
}

void sort() {
llvm::sort(*this, [](const auto &lhs, const auto &rhs) {
return lhs.start() <= rhs.start();
});
}

bool addAdjacent(size_t id) {
bool isAdjacent = false;
for (auto &interval : *this) {
if (interval.adjacent(id)) {
isAdjacent = true;
interval = interval.merge(IntervalT(id));
}
}
return isAdjacent;
}

void add(size_t id) {
if (!addAdjacent(id))
push_back(IntervalT(id));
}
IntervalT unionize() const {
IntervalT res;
if (size()) {
res = front();
for (auto &I : *this)
res = res.merge(I);
}
return res;
}
};

typedef function_ref<LivenessR(Value value)> LivenessF;

void run() {
getValuesAndSizes();
resolveLiveness();
Expand Down Expand Up @@ -289,33 +346,55 @@ class AllocationAnalysis {

/// Computes the liveness range of the allocated value.
/// Each buffer is allocated only once.
void resolveExplicitBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
void resolveExplicitBufferLiveness(LivenessF getLiveness) {
for (auto valueBufferIter : allocation->valueBuffer) {
auto value = valueBufferIter.first;
auto *buffer = valueBufferIter.second;
bufferRange[buffer] = getLiveness(value);
auto ranges = getLiveness(value);
bufferRange[buffer] = ranges.unionize();
}
}

/// Extends the liveness range by unionizing the liveness range of the aliased
/// values because each allocated buffer could be an alias of others, if block
/// arguments are involved.
void resolveAliasBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
/// Only unionize adjacent live ranges to account for loop-carried buffers that
/// are mutually exclusive.
/// Example from stream pipeliner:
/// 3 %b0 = convert_layout %g0 -+
/// 4 %fr = for (.., %arg0 = %b0) { |
/// 5 %gn = load %pc |
/// 6 %bc = convert_layout %arg0 -+
/// 7 %v = add %bc, ...
/// 8 %bn = convert_layout %gn -+
/// 9 %pn = addptr %pc, %cst |
/// 10 } |
/// 11 %be = convert_layout %fr#1 -+
/// 12 %ve = add %be
void resolveAliasBufferLiveness(LivenessF getLiveness) {
for (auto aliasBufferIter : allocation->aliasBuffer) {
auto value = aliasBufferIter.first;
auto buffers = aliasBufferIter.second;
auto range = getLiveness(value);
auto aranges = getLiveness(value);
bool disjoint = aranges.isDisjoint();
for (auto *buffer : buffers) {
auto minId = range.start();
auto maxId = range.end();
auto range = aranges[0];
if (bufferRange.count(buffer)) {
// Extend the allocated buffer's range
minId = std::min(minId, bufferRange[buffer].start());
maxId = std::max(maxId, bufferRange[buffer].end());
auto brange = bufferRange[buffer];
if (disjoint) {
// find adjacent/intersecting
for (auto arange : aranges) {
if (arange.adjacent(brange) ||
arange.intersects(brange))
brange = arange.merge(brange);
}
range = brange;
} else {
// Extend the allocated buffer's range
range = range.merge(brange);
}
}
bufferRange[buffer] = Interval(minId, maxId);
bufferRange[buffer] = range;
}
}
}
Expand Down Expand Up @@ -366,18 +445,13 @@ class AllocationAnalysis {
Liveness liveness(operation);
auto getValueLivenessRange = [&](Value value) {
auto liveOperations = liveness.resolveLiveness(value);
auto minId = std::numeric_limits<size_t>::max();
auto maxId = std::numeric_limits<size_t>::min();
LivenessR ranges;
std::for_each(liveOperations.begin(), liveOperations.end(),
[&](Operation *liveOp) {
if (operationId[liveOp] < minId) {
minId = operationId[liveOp];
}
if ((operationId[liveOp] + 1) > maxId) {
maxId = operationId[liveOp] + 1;
}
ranges.add(operationId[liveOp]);
});
return Interval(minId, maxId);
ranges.sort();
return ranges;
};

resolveExplicitBufferLiveness(getValueLivenessRange);
Expand Down Expand Up @@ -432,9 +506,9 @@ class AllocationAnalysis {
// If the available triple's range is less than a given buffer range,
// we won't know if there has been an overlap without using graph coloring.
// Start -> Liveness Range
using TripleMapT = std::multimap<size_t, Interval<size_t>>;
using TripleMapT = std::multimap<size_t, IntervalT>;
TripleMapT tripleMap;
tripleMap.insert(std::make_pair(0, Interval<size_t>()));
tripleMap.insert(std::make_pair(0, IntervalT()));
SmallVector<BufferT *> xBuffers = buffers;
while (!xBuffers.empty()) {
auto tripleIt = tripleMap.begin();
Expand Down Expand Up @@ -542,6 +616,19 @@ class AllocationAnalysis {
}
}

void dump() const {
llvm::outs() << "DUMP: " << "\n";
for (auto bufferIter : bufferRange) {

llvm::outs() << "ID= " << bufferIter.first->id << "\n";
// llvm::outs() << " Kind= " << kind << "\n";
llvm::outs() << " Size= " << bufferIter.first->size << "\n";
llvm::outs() << " Offs= " << bufferIter.first->offset << "\n";
llvm::outs() << " -> " << bufferIter.second.start() << "\n";
llvm::outs() << " -> " << bufferIter.second.end() << "\n";
}
}

private:
Operation *operation;
Allocation::FuncAllocMapT *funcAllocMap;
Expand Down
5 changes: 4 additions & 1 deletion lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,7 @@ void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>();
printer << "<{"
<< "opIdx = " << getOpIdx() << ", parent = " << getParent();
if (mmaParent && mmaParent.isAmpere())
if ((mmaParent && mmaParent.isAmpere()) || getParent().isa<MfmaEncodingAttr>())
printer << ", kWidth = " << getKWidth();
printer << "}>";
}
Expand Down Expand Up @@ -1221,6 +1221,9 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
if (auto mmaAttr = attr.dyn_cast<MmaEncodingAttr>()) {
os << "mma";
return AliasResult::FinalAlias;
} else if (attr.isa<MfmaEncodingAttr>()) {
os << "mfma";
return AliasResult::FinalAlias;
} else if (auto sharedAttr = attr.dyn_cast<SharedEncodingAttr>()) {
os << "shared";
return AliasResult::FinalAlias;
Expand Down
Loading

0 comments on commit 4db99e0

Please sign in to comment.