diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h
index 9b9d9dace72f..02345d44eec2 100644
--- a/include/triton/Analysis/Allocation.h
+++ b/include/triton/Analysis/Allocation.h
@@ -29,6 +29,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
 template <class T> class Interval {
 public:
   Interval() {}
+  Interval(T S) : Start(S), End(S + 1) {}
   Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); }
   T start() const { return Start; }
   T end() const { return End; }
@@ -44,6 +45,16 @@ template <class T> class Interval {
   bool operator<(const Interval &R) const {
     return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
   }
+  bool adjacent(T Addr) const {
+    return Addr + 1 == Start || Addr == End;
+  }
+  bool adjacent(const Interval &R) const {
+    return adjacent(R.Start) || adjacent(R.End - 1);
+  }
+
+  Interval merge(const Interval &R) const {
+    return Interval(std::min(Start, R.Start), std::max(End, R.End));
+  }

 private:
   T Start = std::numeric_limits<T>::min();
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
index e35ee2b576c3..47906a34d481 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
@@ -38,6 +38,13 @@ def TritonGPU_Dialect : Dialect {
      }
      return threadsPerWarp.cast<IntegerAttr>().getInt();
    }
+    static int getSharedSize(ModuleOp mod) {
+      Attribute sharedAttr = mod->getDiscardableAttr("triton_gpu.shared");
+      if (!sharedAttr) {
+        return 0;
+      }
+      return sharedAttr.cast<IntegerAttr>().getInt();
+    }
  }];
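
Note (not part of the patch): the new Interval helpers above operate on half-open
[Start, End) ranges, so a point is adjacent when it immediately precedes Start or
equals End, and merge returns the convex hull of two ranges. A minimal standalone
C++ sketch of those semantics, assuming nothing beyond the standard library:

    #include <algorithm>
    #include <cassert>

    // Standalone copy of the new helpers, for illustration only; the real
    // class lives in include/triton/Analysis/Allocation.h.
    template <class T> struct Interval {
      T Start, End; // half-open: [Start, End)
      bool adjacent(T Addr) const { return Addr + 1 == Start || Addr == End; }
      Interval merge(const Interval &R) const {
        return {std::min(Start, R.Start), std::max(End, R.End)};
      }
    };

    int main() {
      Interval<int> a{3, 5};  // covers ids 3 and 4
      assert(a.adjacent(5));  // 5 immediately follows [3, 5)
      assert(a.adjacent(2));  // 2 immediately precedes it
      assert(!a.adjacent(6)); // 6 leaves a gap
      Interval<int> b{5, 8};
      Interval<int> m = a.merge(b); // convex hull of the two ranges
      assert(m.Start == 3 && m.End == 8);
    }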
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 279b79cb8932..ccd6b4d7a7e8 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -137,11 +137,68 @@ class AllocationAnalysis {
   using BufferT = Allocation::BufferT;

   /// Value -> Liveness Range
+  using IntervalT = Interval<size_t>;
   /// Use MapVector to ensure determinism.
-  using BufferRangeMapT = llvm::MapVector<BufferT *, Interval<size_t>>;
+  using BufferRangeMapT = llvm::MapVector<BufferT *, IntervalT>;
   /// Nodes -> Nodes
   using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;

+  /// Set of Liveness Intervals
+  class LivenessR : public SmallVector<IntervalT> {
+  public:
+    LivenessR() = default;
+    LivenessR(const LivenessR &) = default;
+
+    /// Disjointness: requires the intervals to be sorted by start, so the
+    /// first gap proves the set is disjoint.
+    bool isDisjoint() const {
+      if (size() < 2)
+        return false;
+      auto maxId = (*this)[0].end();
+      for (auto rng : *this) {
+        if (rng.start() <= maxId) {
+          // Adjoining or overlapping; extend the covered prefix.
+          maxId = std::max(maxId, rng.end());
+        } else
+          return true;
+      }
+      return false;
+    }
+
+    void sort() {
+      llvm::sort(*this, [](const auto &lhs, const auto &rhs) {
+        return lhs.start() < rhs.start();
+      });
+    }
+
+    bool addAdjacent(size_t id) {
+      bool isAdjacent = false;
+      for (auto &interval : *this) {
+        if (interval.adjacent(id)) {
+          isAdjacent = true;
+          interval = interval.merge(IntervalT(id));
+        }
+      }
+      return isAdjacent;
+    }
+
+    void add(size_t id) {
+      if (!addAdjacent(id))
+        push_back(IntervalT(id));
+    }
+
+    IntervalT unionize() const {
+      IntervalT res;
+      if (size()) {
+        res = front();
+        for (auto &I : *this)
+          res = res.merge(I);
+      }
+      return res;
+    }
+  };
+
+  typedef function_ref<LivenessR(Value)> LivenessF;
+
   void run() {
     getValuesAndSizes();
     resolveLiveness();
@@ -289,33 +346,55 @@ class AllocationAnalysis {
   /// Computes the liveness range of the allocated value.
   /// Each buffer is allocated only once.
-  void resolveExplicitBufferLiveness(
-      function_ref<Interval<size_t>(Value value)> getLiveness) {
+  void resolveExplicitBufferLiveness(LivenessF getLiveness) {
     for (auto valueBufferIter : allocation->valueBuffer) {
       auto value = valueBufferIter.first;
       auto *buffer = valueBufferIter.second;
-      bufferRange[buffer] = getLiveness(value);
+      auto ranges = getLiveness(value);
+      bufferRange[buffer] = ranges.unionize();
     }
   }

   /// Extends the liveness range by unionizing the liveness range of the aliased
   /// values because each allocated buffer could be an alias of others, if block
   /// arguments are involved.
-  void resolveAliasBufferLiveness(
-      function_ref<Interval<size_t>(Value value)> getLiveness) {
+  /// Only unionize adjacent live ranges, to account for loop-carried buffers
+  /// that are mutually exclusive.
+  /// Example from the stream pipeliner:
+  ///    3 %b0 = convert_layout %g0      -+
+  ///    4 %fr = for (.., %arg0 = %b0) {  |
+  ///    5   %gn = load %pc               |
+  ///    6   %bc = convert_layout %arg0  -+
+  ///    7   %v = add %bc, ...
+  ///    8   %bn = convert_layout %gn    -+
+  ///    9   %pn = addptr %pc, %cst       |
+  ///   10 }                              |
+  ///   11 %be = convert_layout %fr#1    -+
+  ///   12 %ve = add %be
+  void resolveAliasBufferLiveness(LivenessF getLiveness) {
     for (auto aliasBufferIter : allocation->aliasBuffer) {
       auto value = aliasBufferIter.first;
       auto buffers = aliasBufferIter.second;
-      auto range = getLiveness(value);
+      auto aranges = getLiveness(value);
+      bool disjoint = aranges.isDisjoint();
       for (auto *buffer : buffers) {
-        auto minId = range.start();
-        auto maxId = range.end();
+        auto range = aranges[0];
         if (bufferRange.count(buffer)) {
-          // Extend the allocated buffer's range
-          minId = std::min(minId, bufferRange[buffer].start());
-          maxId = std::max(maxId, bufferRange[buffer].end());
+          auto brange = bufferRange[buffer];
+          if (disjoint) {
+            // Merge only the adjacent or intersecting live ranges.
+            for (auto arange : aranges) {
+              if (arange.adjacent(brange) ||
+                  arange.intersects(brange))
+                brange = arange.merge(brange);
+            }
+            range = brange;
+          } else {
+            // Extend the allocated buffer's range.
+            range = range.merge(brange);
+          }
         }
-        bufferRange[buffer] = Interval(minId, maxId);
+        bufferRange[buffer] = range;
       }
     }
   }
@@ -366,18 +445,13 @@ class AllocationAnalysis {
     Liveness liveness(operation);
     auto getValueLivenessRange = [&](Value value) {
       auto liveOperations = liveness.resolveLiveness(value);
-      auto minId = std::numeric_limits<size_t>::max();
-      auto maxId = std::numeric_limits<size_t>::min();
+      LivenessR ranges;
       std::for_each(liveOperations.begin(), liveOperations.end(),
                     [&](Operation *liveOp) {
-                      if (operationId[liveOp] < minId) {
-                        minId = operationId[liveOp];
-                      }
-                      if ((operationId[liveOp] + 1) > maxId) {
-                        maxId = operationId[liveOp] + 1;
-                      }
+                      ranges.add(operationId[liveOp]);
                     });
-      return Interval(minId, maxId);
+      ranges.sort();
+      return ranges;
     };

     resolveExplicitBufferLiveness(getValueLivenessRange);
@@ -432,9 +506,9 @@ class AllocationAnalysis {
     // If the available triple's range is less than a given buffer range,
     // we won't know if there has been an overlap without using graph coloring.
     // Start -> Liveness Range
-    using TripleMapT = std::multimap<size_t, Interval<size_t>>;
+    using TripleMapT = std::multimap<size_t, IntervalT>;
     TripleMapT tripleMap;
-    tripleMap.insert(std::make_pair(0, Interval<size_t>()));
+    tripleMap.insert(std::make_pair(0, IntervalT()));
     SmallVector<BufferT *> xBuffers = buffers;
     while (!xBuffers.empty()) {
       auto tripleIt = tripleMap.begin();
@@ -542,6 +616,17 @@ class AllocationAnalysis {
     }
   }

+  void dump() const {
+    llvm::outs() << "DUMP:\n";
+    for (auto bufferIter : bufferRange) {
+      llvm::outs() << "ID= " << bufferIter.first->id << "\n";
+      llvm::outs() << "  Size= " << bufferIter.first->size << "\n";
+      llvm::outs() << "  Offs= " << bufferIter.first->offset << "\n";
+      llvm::outs() << "  -> " << bufferIter.second.start() << "\n";
+      llvm::outs() << "  -> " << bufferIter.second.end() << "\n";
+    }
+  }
+
 private:
   Operation *operation;
   Allocation::FuncAllocMapT *funcAllocMap;
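
Note (illustration only): the central change above is that a value's liveness is
now a set of intervals (LivenessR) rather than a single convex range. The sketch
below mirrors the add/sort/isDisjoint logic on the operation ids from the
stream-pipeliner comment; the Range and LivenessSketch names are ad hoc, not
part of the patch:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    // Pared-down LivenessR over half-open [start, end) ranges of op ids.
    struct Range { size_t start, end; };

    struct LivenessSketch {
      std::vector<Range> ranges;

      void add(size_t id) {
        for (auto &r : ranges) {
          if (id + 1 == r.start) { r.start = id; return; }   // adjacent below
          if (id == r.end)       { r.end = id + 1; return; } // adjacent above
          if (id >= r.start && id < r.end) return;           // already covered
        }
        ranges.push_back({id, id + 1});
      }

      void sort() {
        std::sort(ranges.begin(), ranges.end(),
                  [](const Range &a, const Range &b) { return a.start < b.start; });
      }

      bool isDisjoint() const {
        if (ranges.size() < 2) return false;
        size_t maxId = ranges[0].end;
        for (const auto &r : ranges) {
          if (r.start > maxId) return true; // found a gap
          maxId = std::max(maxId, r.end);
        }
        return false;
      }
    };

    int main() {
      // %arg0 in the pipeliner example is live at ops {4, 5, 6} (loop body)
      // and its alias %fr#1 at {11, 12} (the epilogue use).
      LivenessSketch l;
      for (size_t id : {4, 5, 6, 11, 12}) l.add(id);
      l.sort();
      assert(l.ranges.size() == 2); // [4, 7) and [11, 13)
      assert(l.isDisjoint());       // gap at ops 7..10 -> buffers may overlap
    }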
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 5632b0f6b0ff..95524504f4ce 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1095,7 +1095,7 @@ void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
   auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>();
   printer << "<{" << "opIdx = " << getOpIdx() << ", parent = " << getParent();
-  if (mmaParent && mmaParent.isAmpere())
+  if ((mmaParent && mmaParent.isAmpere()) || getParent().isa<MfmaEncodingAttr>())
     printer << ", kWidth = " << getKWidth();
   printer << "}>";
 }
@@ -1221,6 +1221,9 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
     if (auto mmaAttr = attr.dyn_cast<MmaEncodingAttr>()) {
       os << "mma";
       return AliasResult::FinalAlias;
+    } else if (attr.isa<MfmaEncodingAttr>()) {
+      os << "mfma";
+      return AliasResult::FinalAlias;
     } else if (auto sharedAttr = attr.dyn_cast<SharedEncodingAttr>()) {
       os << "shared";
       return AliasResult::FinalAlias;
diff --git a/test/Conversion/minimize_alloc.mlir b/test/Conversion/minimize_alloc.mlir
new file mode 100644
index 000000000000..f663d7f9247e
--- /dev/null
+++ b/test/Conversion/minimize_alloc.mlir
@@ -0,0 +1,116 @@
+// RUN: triton-opt --convert-triton-gpu-to-llvm %s | FileCheck %s
+
+// CHECK: module attributes {{{.*}}, triton_gpu.shared = 9216 : i32
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 8, order = [1, 0]}>
+#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>
+    %cst_0 = arith.constant dense<32> : tensor<64x32xi32, #blocked>
+    %c31_i32 = arith.constant 31 : i32
+    %c63_i32 = arith.constant 63 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.addi %arg3, %c63_i32 : i32
+    %2 = arith.divsi %1, %c64_i32 : i32
+    %3 = arith.addi %arg4, %c63_i32 : i32
+    %4 = arith.divsi %3, %c64_i32 : i32
+    %5 = arith.muli %4, %c4_i32 : i32
+    %6 = arith.divsi %0, %5 : i32
+    %7 = arith.muli %6, %c4_i32 : i32
+    %8 = arith.subi %2, %7 : i32
+    %9 = "triton_gpu.cmpi"(%8, %c4_i32) <{predicate = 2 : i64}> : (i32, i32) -> i1
+    %10 = arith.select %9, %8, %c4_i32 : i32
+    %11 = arith.remsi %0, %10 : i32
+    %12 = arith.addi %7, %11 : i32
+    %13 = arith.remsi %0, %5 : i32
+    %14 = arith.divsi %13, %10 : i32
+    %15 = arith.muli %12, %c64_i32 : i32
+    %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
+    %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    %19 = tt.splat %15 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %20 = tt.splat %15 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    %21 = arith.addi %19, %16 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %22 = arith.addi %20, %18 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    %23 = arith.muli %14, %c64_i32 : i32
+    %24 = tt.splat %23 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
+    %25 = arith.addi %24, %17 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
+    %26 = tt.expand_dims %21 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi32, #blocked>
+    %27 = tt.expand_dims %22 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
+    %28 = tt.splat %arg6 : (i32) -> tensor<64x1xi32, #blocked>
+    %29 = arith.muli %26, %28 : tensor<64x1xi32, #blocked>
+    %30 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked>
+    %31 = tt.addptr %30, %29 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
+    %32 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
+    %33 = tt.expand_dims %32 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x32xi32, #blocked>
+    %34 = tt.broadcast %31 : (tensor<64x1x!tt.ptr<f16>, #blocked>) -> tensor<64x32x!tt.ptr<f16>, #blocked>
+    %35 = tt.broadcast %33 : (tensor<1x32xi32, #blocked>) -> tensor<64x32xi32, #blocked>
+    %36 = tt.addptr %34, %35 : tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<64x32xi32, #blocked>
+    %37 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    %38 = tt.expand_dims %37 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<32x1xi32, #blocked1>
+    %39 = tt.splat %arg7 : (i32) -> tensor<32x1xi32, #blocked1>
+    %40 = arith.muli %38, %39 : tensor<32x1xi32, #blocked1>
+    %41 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<32x1x!tt.ptr<f16>, #blocked1>
+    %42 = tt.addptr %41, %40 : tensor<32x1x!tt.ptr<f16>, #blocked1>, tensor<32x1xi32, #blocked1>
+    %43 = tt.expand_dims %25 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
+    %44 = tt.broadcast %42 : (tensor<32x1x!tt.ptr<f16>, #blocked1>) -> tensor<32x64x!tt.ptr<f16>, #blocked1>
+    %45 = tt.broadcast %43 : (tensor<1x64xi32, #blocked1>) -> tensor<32x64xi32, #blocked1>
+    %46 = tt.addptr %44, %45 : tensor<32x64x!tt.ptr<f16>, #blocked1>, tensor<32x64xi32, #blocked1>
+    %47 = arith.addi %arg5, %c31_i32 : i32
+    %48 = arith.divsi %47, %c32_i32 : i32
+    %49 = arith.muli %arg7, %c32_i32 : i32
+    %50 = tt.splat %49 : (i32) -> tensor<32x64xi32, #blocked1>
+    %51 = tt.load %36 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16, #blocked>
+    %52 = triton_gpu.convert_layout %51 : (tensor<64x32xf16, #blocked>) -> tensor<64x32xf16, #shared>
+    %53 = tt.load %46 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16, #blocked1>
+    %54 = triton_gpu.convert_layout %53 : (tensor<32x64xf16, #blocked1>) -> tensor<32x64xf16, #shared1>
+    %55 = tt.addptr %36, %cst_0 : tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<64x32xi32, #blocked>
+    %56 = tt.addptr %46, %50 : tensor<32x64x!tt.ptr<f16>, #blocked1>, tensor<32x64xi32, #blocked1>
+    %57 = arith.subi %48, %c1_i32 : i32
+    cf.br ^bb1(%c0_i32, %cst, %52, %54, %55, %56 : i32, tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>, tensor<64x32xf16, #shared>, tensor<32x64xf16, #shared1>, tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<32x64x!tt.ptr<f16>, #blocked1>)
+  ^bb1(%58: i32, %59: tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>, %60: tensor<64x32xf16, #shared>, %61: tensor<32x64xf16, #shared1>, %62: tensor<64x32x!tt.ptr<f16>, #blocked>, %63: tensor<32x64x!tt.ptr<f16>, #blocked1>):  // 2 preds: ^bb0, ^bb2
+    %64 = arith.cmpi slt, %58, %57 : i32
+    cf.cond_br %64, ^bb2, ^bb3
+  ^bb2:  // pred: ^bb1
+    %65 = tt.load %62 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16, #blocked>
+    %66 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16, #blocked1>
+    %67 = triton_gpu.convert_layout %60 : (tensor<64x32xf16, #shared>) -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>>
+    %68 = triton_gpu.convert_layout %61 : (tensor<32x64xf16, #shared1>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>>
+    %69 = tt.dot %67, %68, %59 {allowTF32 = true} : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>> -> tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>
+    %70 = tt.addptr %62, %cst_0 : tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<64x32xi32, #blocked>
+    %71 = tt.addptr %63, %50 : tensor<32x64x!tt.ptr<f16>, #blocked1>, tensor<32x64xi32, #blocked1>
+    %72 = triton_gpu.convert_layout %65 : (tensor<64x32xf16, #blocked>) -> tensor<64x32xf16, #shared>
+    %73 = triton_gpu.convert_layout %66 : (tensor<32x64xf16, #blocked1>) -> tensor<32x64xf16, #shared1>
+    %74 = arith.addi %58, %c1_i32 : i32
+    cf.br ^bb1(%74, %69, %72, %73, %70, %71 : i32, tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>, tensor<64x32xf16, #shared>, tensor<32x64xf16, #shared1>, tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<32x64x!tt.ptr<f16>, #blocked1>)
+  ^bb3:  // pred: ^bb1
+    %75 = triton_gpu.convert_layout %60 : (tensor<64x32xf16, #shared>) -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>>
+    %76 = triton_gpu.convert_layout %61 : (tensor<32x64xf16, #shared1>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>>
+    %77 = tt.dot %75, %76, %59 {allowTF32 = true} : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>, kWidth = 8}>> -> tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>
+    %78 = arith.truncf %77 : tensor<64x64xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>> to tensor<64x64xf16, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>
+    %79 = tt.splat %arg8 : (i32) -> tensor<64x1xi32, #blocked1>
+    %80 = arith.muli %79, %27 : tensor<64x1xi32, #blocked1>
+    %81 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked1>
+    %82 = tt.addptr %81, %80 : tensor<64x1x!tt.ptr<f16>, #blocked1>, tensor<64x1xi32, #blocked1>
+    %83 = tt.broadcast %82 : (tensor<64x1x!tt.ptr<f16>, #blocked1>) -> tensor<64x64x!tt.ptr<f16>, #blocked1>
+    %84 = tt.broadcast %43 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
+    %85 = tt.addptr %83, %84 : tensor<64x64x!tt.ptr<f16>, #blocked1>, tensor<64x64xi32, #blocked1>
+    %86 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked1>
+    %87 = "triton_gpu.cmpi"(%27, %86) <{predicate = 2 : i64}> : (tensor<64x1xi32, #blocked1>, tensor<64x1xi32, #blocked1>) -> tensor<64x1xi1, #blocked1>
+    %88 = tt.splat %arg4 : (i32) -> tensor<1x64xi32, #blocked1>
+    %89 = "triton_gpu.cmpi"(%43, %88) <{predicate = 2 : i64}> : (tensor<1x64xi32, #blocked1>, tensor<1x64xi32, #blocked1>) -> tensor<1x64xi1, #blocked1>
+    %90 = tt.broadcast %87 : (tensor<64x1xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
+    %91 = tt.broadcast %89 : (tensor<1x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
+    %92 = arith.andi %90, %91 : tensor<64x64xi1, #blocked1>
+    %93 = triton_gpu.convert_layout %78 : (tensor<64x64xf16, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed = false}>>) -> tensor<64x64xf16, #blocked1>
+    tt.store %85, %93, %92 {cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf16, #blocked1>
+    tt.return
+  }
+}
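
Note (illustration only): for alias buffers with disjoint liveness, the new
disjoint branch of resolveAliasBufferLiveness widens a buffer only by the alias
ranges it actually touches. The ad-hoc sketch below (names are not from the
patch) reproduces that computation for the pipeliner example, where the alias
is live over [3, 7) and [11, 13) and the buffer starts at [3, 4):

    #include <algorithm>
    #include <cassert>
    #include <vector>

    struct R { size_t s, e; }; // half-open [s, e)
    static bool adjacent(const R &a, const R &b) { return a.e == b.s || b.e == a.s; }
    static bool intersects(const R &a, const R &b) { return a.s < b.e && b.s < a.e; }
    static R merge(const R &a, const R &b) {
      return {std::min(a.s, b.s), std::max(a.e, b.e)};
    }

    int main() {
      std::vector<R> aliasRanges = {{3, 7}, {11, 13}}; // disjoint alias liveness
      R buffer = {3, 4}; // %b0's buffer, initially live only at op 3

      R range = buffer;
      for (const R &a : aliasRanges)
        if (adjacent(a, range) || intersects(a, range))
          range = merge(a, range);

      // The buffer grows to [3, 7) instead of the whole [3, 13) span, so it is
      // dead during ops 7..10 and a later buffer may reuse its offset.
      assert(range.s == 3 && range.e == 7);
    }

In the test above, this tighter liveness is what lets the loop-carried #shared
and #shared1 buffers share memory; the CHECK line pins the resulting footprint
at triton_gpu.shared = 9216 bytes.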