Skip to content

Commit

Permalink
[LAYOUTS] Make operator* associative and dimension-order-preserving (#…
Browse files Browse the repository at this point in the history
…5928)

To do so, we compute the supremum in the poset of ordered lists of strings
without repetition, ordered by inclusion, and
error out if it does not exist. We break ties in favour of the left operand.

We do this as this is the natural order for the output dims.
  • Loading branch information
lezcano authored Feb 15, 2025
1 parent f5bd3fd commit 3463719
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 42 deletions.
6 changes: 6 additions & 0 deletions include/triton/Tools/LayoutUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank);
LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);

// Compute the supremum of two ordered lists `x` and `y`: a list containing
// the elements of both, in an order compatible with each input.
// Error out if the supremum does not exist (e.g. [a, b] and [b, a]).
// If more than one ordering is compatible, ties are broken in favour of the
// first list: elements of `x` are emitted before unrelated elements of `y`
// (e.g. [a, b], [a, c] -> [a, b, c]).
SmallVector<StringAttr> supremum(const SmallVector<StringAttr> &x,
const SmallVector<StringAttr> &y);
} // namespace mlir::triton

#endif // TRITON_TOOLS_LAYOUTUTILS_H
26 changes: 19 additions & 7 deletions include/triton/Tools/LinearLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,13 @@ class LinearLayout {
// from a larger one.
[[nodiscard]] LinearLayout concatOuts(const LinearLayout &other) const;

// Creates a new layout which, roughly speaking, is equivalent to one where
// every element of the `outer` layout is replaced by a full instance of the
// `inner` layout.
// Computes the direct sum of two layouts.
// https://en.wikipedia.org/wiki/Direct_sum#Direct_sum_of_matrices
//
// Roughly speaking, the first layout acts on the first part of the input
// dimensions, and the second layout acts on the second part.
// In other words, it's the generalisation of concatenation of the inputs
// to linear maps.
//
// Examples:
//
Expand All @@ -572,19 +576,27 @@ class LinearLayout {
//
// - identity1D(4, "i", "o") * identity1D(2, "i", "o") ==
// identity1D(8, "i", "o")
// The output matrix is [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
//
// - identity1D(4, "i", "o") * zeros1D(2, "i", "o") => L(x) = x % 4
// for x in [0,8).
// The output matrix is [[1, 0, 0], [0, 1, 0], [0, 0, 0]]
//
// - zeros1D(2, "i", "o") * identity1D(4, "i", "o") => L(x) = x / 2
// for x in [0,8).
//
// The output matrix is [[0, 0, 0], [0, 1, 0], [0, 0, 1]]

// - identity1D(4, "i", "o1") * identity1D(8, "i", "o2") =>
// L(x) = (x % 4, x / 4) for x in [0,32).
// The output dims are ("o1", "o2") in that order.
//
// If the input (or output) dims of the layouts are not the same, we take
// the supremum of the two ordered lists with the inclusion, respecting the
// order. If multiple suprema exist, we bias towards the first list.
// e.g. sup([a, b], [a, c]) = [a, b, c], sup([a, b], [b, c]) = [a, b, c]
// sup([a, b], [b, a]) = error! Supremum does not exist.
//
// Notice that this operation is not commutative. It's also not associative.
// TODO(jlebar): Can I modify the definition to make it associative? Pretty
// confusing if not. If I can't, add an example.
// Notice that this operation is not commutative, but it is associative.
//
// Requires: Any in/out dimensions which are in both outer and inner appear in
// the same relative order.
Expand Down
3 changes: 0 additions & 3 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2880,9 +2880,6 @@ struct TritonGPUInferLayoutInterface
if (fwdInference) {
auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]);
newLl = split * ll;
// FIXME!!!!
// operator* transposes the output dimensions??!! WTF
newLl = newLl.transposeOuts(outDims);
} else {
// TODO This requires a division algorithm!
// Implement manually ll.divideLeft(split)
Expand Down
63 changes: 63 additions & 0 deletions lib/Tools/LayoutUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,67 @@ LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
return ret;
}

// Compute the supremum of two ordered lists without repeated elements, i.e.
// a list that contains every element of `x` and `y` and keeps the relative
// order of each input.
// If several orderings are compatible, ties are broken towards the first
// list: an element of `x` that `y` does not constrain is emitted first.
// e.g. sup([a, b], [a, c]) = [a, b, c], sup([a, b], [b, c]) = [a, b, c]
// If `x` and `y` order two common elements differently no supremum exists and
// a fatal error is reported (in every build mode).
// e.g. sup([a, b], [b, a]) = error! Supremum does not exist.
SmallVector<StringAttr> supremum(const SmallVector<StringAttr> &x,
                                 const SmallVector<StringAttr> &y) {
  llvm::SetVector<StringAttr> result;
  // Index of every element in its list; only used for membership queries.
  DenseMap<StringAttr, int> posX, posY;
  for (auto [idx, elem] : llvm::enumerate(x))
    posX[elem] = idx;
  for (auto [idx, elem] : llvm::enumerate(y))
    posY[elem] = idx;

  size_t i = 0, j = 0;
  while (i < x.size() || j < y.size()) {
    // Skip heads that were already emitted (tolerates repeated elements).
    while (i < x.size() && result.contains(x[i]))
      ++i;
    while (j < y.size() && result.contains(y[j]))
      ++j;
    const bool hasX = i < x.size();
    const bool hasY = j < y.size();
    if (!hasX && !hasY)
      break;

    // Both lists agree on the next element: consume it from both.
    if (hasX && hasY && x[i] == y[j]) {
      result.insert(x[i]);
      ++i, ++j;
      continue;
    }
    // The head of `x` is unconstrained by `y`: emit it first (left bias).
    if (hasX && !posY.count(x[i])) {
      result.insert(x[i]);
      ++i;
      continue;
    }
    // Otherwise emit the head of `y` if `x` does not constrain it.
    if (hasY && !posX.count(y[j])) {
      result.insert(y[j]);
      ++j;
      continue;
    }
    // Both heads are common elements that have not been emitted yet, and they
    // differ: `x` places x[i] before y[j] while `y` places y[j] before x[i].
    // No list can respect both orders.
    llvm::report_fatal_error("Supremum does not exist");
  }
  return to_vector(result);
}

} // namespace mlir::triton
42 changes: 10 additions & 32 deletions lib/Tools/LinearLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
#include "third_party/f2reduce/f2reduce.h"
#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"

Expand Down Expand Up @@ -154,36 +156,6 @@ void assertDimsSubsetIgnoringOrder(T &&small, U &&big) {
triton::join(big, ", ") + "]");
}
}

// Check that elements common to both aDims and bDims
// appear in the same relative order.
template <typename T, typename U>
void assertCommonDimsSameOrder(T &&aDims, U &&bDims) {
SmallDenseSet<StringAttr> aDimsSet(aDims.begin(), aDims.end());
SmallDenseSet<StringAttr> bDimsSet(bDims.begin(), bDims.end());

std::vector<StringAttr> aCommonDims;
for (StringAttr dim : aDims) {
if (bDimsSet.contains(dim)) {
aCommonDims.push_back(dim);
}
}

std::vector<StringAttr> bCommonDims;
for (StringAttr dim : bDims) {
if (aDimsSet.contains(dim)) {
bCommonDims.push_back(dim);
}
}

if (aCommonDims != bCommonDims) {
llvm::report_fatal_error("All a/b dimensions common to both layouts "
"must appear in the same relative order, but they "
"don't.\na:" +
Twine(triton::join(aDims, ", ")) +
"\nb: " + triton::join(bDims, ", "));
}
}
} // anonymous namespace

/*static*/ std::optional<LinearLayout>
Expand Down Expand Up @@ -590,14 +562,20 @@ LinearLayout LinearLayout::concatOuts(const LinearLayout &other) const {

LinearLayout operator*(LinearLayout inner, LinearLayout outer) {
// Check that dims common to outer and inner have the same relative order.
assertCommonDimsSameOrder(inner.getOutDimNames(), outer.getOutDimNames());
assertCommonDimsSameOrder(inner.getInDimNames(), outer.getInDimNames());
auto inDims = supremum(llvm::to_vector(inner.getInDimNames()),
llvm::to_vector(outer.getInDimNames()));
auto outDims = supremum(llvm::to_vector(inner.getOutDimNames()),
llvm::to_vector(outer.getOutDimNames()));

// Get the sizeLog2 of all input and output dimensions we're going to
// consider, in order. `inner` is more minor, so its dimensions come
// first.
llvm::MapVector<StringAttr, int32_t> inDimSizesLog2;
llvm::MapVector<StringAttr, int32_t> outDimSizesLog2;
for (const auto &dim : inDims)
inDimSizesLog2.insert({dim, 0});
for (const auto &dim : outDims)
outDimSizesLog2.insert({dim, 0});
for (const auto &layout : {inner, outer}) {
for (StringAttr inDim : layout.getInDimNames()) {
inDimSizesLog2[inDim] += layout.getInDimSizeLog2(inDim);
Expand Down
64 changes: 64 additions & 0 deletions unittest/Tools/LinearLayoutTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,70 @@ TEST_F(LinearLayoutTest, BlackwellMixedPrecisionDotScaledSMEMSwizzled) {
}
}

// Test helper: turns a list of string names into StringAttrs created in
// `ctx`, preserving the order of `list`.
static SmallVector<StringAttr> makeList(MLIRContext *ctx,
                                        llvm::ArrayRef<llvm::StringRef> list) {
  SmallVector<StringAttr> attrs;
  attrs.reserve(list.size());
  for (llvm::StringRef name : list)
    attrs.push_back(StringAttr::get(ctx, name));
  return attrs;
}

TEST(SupremumTest, IdenticalLists) {
  MLIRContext ctx;
  // The supremum of a list with an identical copy is the list itself.
  const SmallVector<StringAttr> lhs = makeList(&ctx, {"a", "b", "c"});
  const SmallVector<StringAttr> rhs = makeList(&ctx, {"a", "b", "c"});
  EXPECT_EQ(supremum(lhs, rhs), lhs);
}

TEST(SupremumTest, NonUniqueSupremumFirstListPriority) {
  MLIRContext ctx;
  // b and c are unrelated, so the first list wins the tie:
  // sup([a, b], [a, c]) == [a, b, c].
  const SmallVector<StringAttr> lhs = makeList(&ctx, {"a", "b"});
  const SmallVector<StringAttr> rhs = makeList(&ctx, {"a", "c"});
  const SmallVector<StringAttr> expected = makeList(&ctx, {"a", "b", "c"});
  EXPECT_EQ(supremum(lhs, rhs), expected);
}

TEST(SupremumTest, NonUniqueSupremumAlternate) {
  MLIRContext ctx;
  // The lists overlap on b only: sup([a, b], [b, c]) == [a, b, c].
  const SmallVector<StringAttr> lhs = makeList(&ctx, {"a", "b"});
  const SmallVector<StringAttr> rhs = makeList(&ctx, {"b", "c"});
  const SmallVector<StringAttr> expected = makeList(&ctx, {"a", "b", "c"});
  EXPECT_EQ(supremum(lhs, rhs), expected);
}

TEST(SupremumTest, DifferentLengths) {
  MLIRContext ctx;
  // Inputs of different sizes: sup([a, b, c], [a, d]) == [a, b, c, d].
  const SmallVector<StringAttr> lhs = makeList(&ctx, {"a", "b", "c"});
  const SmallVector<StringAttr> rhs = makeList(&ctx, {"a", "d"});
  const SmallVector<StringAttr> expected =
      makeList(&ctx, {"a", "b", "c", "d"});
  EXPECT_EQ(supremum(lhs, rhs), expected);
}

TEST(SupremumTest, SupremumEmptyLists) {
  MLIRContext ctx;
  // Two empty lists have an empty supremum.
  const SmallVector<StringAttr> lhs;
  const SmallVector<StringAttr> rhs;
  const SmallVector<StringAttr> merged = supremum(lhs, rhs);
  EXPECT_TRUE(merged.empty());
}

TEST(SupremumTest, OneEmptyList) {
  MLIRContext ctx;
  // An empty second list adds nothing: sup([a, b], []) == [a, b].
  const SmallVector<StringAttr> lhs = makeList(&ctx, {"a", "b"});
  const SmallVector<StringAttr> rhs;
  const SmallVector<StringAttr> expected = makeList(&ctx, {"a", "b"});
  EXPECT_EQ(supremum(lhs, rhs), expected);
}

// NOTE(review): LLVM_ENABLE_ASSERTIONS looks like a CMake option rather than a
// preprocessor define; confirm the build passes it to the compiler, otherwise
// this test is never compiled.
#ifdef LLVM_ENABLE_ASSERTIONS
TEST(SupremumTest, ErrorOnInconsistentOrder) {
  MLIRContext ctx;
  // [a, b] and [b, a] order their common elements differently, so no
  // supremum exists and the call is expected to die with the
  // "Supremum does not exist" diagnostic.
  const SmallVector<StringAttr> lhs = makeList(&ctx, {"a", "b"});
  const SmallVector<StringAttr> rhs = makeList(&ctx, {"b", "a"});
  ASSERT_DEATH({ supremum(lhs, rhs); }, "Supremum does not exist");
}
#endif
} // anonymous namespace
} // namespace mlir::triton

Expand Down

0 comments on commit 3463719

Please sign in to comment.