Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
scxiao committed Feb 27, 2025
1 parent a43f79c commit 3428124
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 45 deletions.
8 changes: 0 additions & 8 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1715,8 +1715,6 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const

val = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)
dst = torch.zeros((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)
print(f"val_shape = {val.shape}")

dst_ref = dst.clone()

cnt = 0
Expand All @@ -1726,14 +1724,8 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
dst_ref[i][elem] += val[i][j]
cnt += 1

print(f"val = {val}")
print(f"idx = {idx}")

kernel[(1, )](val, idx, dst, shape0, shape1, mask_step, 64, num_ctas=num_ctas)

print(f"dst_ref = {dst_ref}")
print(f"dst = {dst}")

np.testing.assert_allclose(to_numpy(dst_ref), to_numpy(dst), atol=1e-2)


Expand Down
89 changes: 52 additions & 37 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1212,36 +1212,50 @@ Value genPrefixSum(PatternRewriter &rewriter, Value v0) {

Value v1 = v0;
// v_add_f32 v1, v0, v0 row_shr:1 bound_ctrl:0
Value tmp = rewriter.create<ROCDL::DPPUpdateOp>(
loc, i32_ty, old, v0, 0x111, 0xF, 0xF, false).getResult();
Value tmp = rewriter
.create<ROCDL::DPPUpdateOp>(loc, i32_ty, old, v0, 0x111, 0xF,
0xF, false)
.getResult();
v1 = b.add(v1, tmp);
// v_add_f32 v1, v0, v1 row_shr:2 bound_ctrl:0
tmp = rewriter.create<ROCDL::DPPUpdateOp>(
loc, i32_ty, old, v0, 0x112, 0xF, 0xF, false).getResult();
tmp = rewriter
.create<ROCDL::DPPUpdateOp>(loc, i32_ty, old, v0, 0x112, 0xF, 0xF,
false)
.getResult();
v1 = b.add(v1, tmp);
// v_add_f32 v1, v0, v1 row_shr:3 bound_ctrl:0
tmp = rewriter.create<ROCDL::DPPUpdateOp>(
loc, i32_ty, old, v0, 0x113, 0xF, 0xF, false).getResult();
tmp = rewriter
.create<ROCDL::DPPUpdateOp>(loc, i32_ty, old, v0, 0x113, 0xF, 0xF,
false)
.getResult();
v1 = b.add(v1, tmp);

// v_add_f32 v1, v1, v1 row_shr:4 bank_mask:0xe
tmp = rewriter.create<ROCDL::DPPUpdateOp>(
loc, i32_ty, old, v1, 0x114, 0xF, 0xE, true).getResult();
tmp = rewriter
.create<ROCDL::DPPUpdateOp>(loc, i32_ty, old, v1, 0x114, 0xF, 0xE,
true)
.getResult();
v1 = b.add(v1, tmp);

// v_add_f32 v1, v1, v1 row_shr:8 bank_mask:0xc
tmp = rewriter.create<ROCDL::DPPUpdateOp>(
loc, i32_ty, old, v1, 0x118, 0xF, 0xC, true).getResult();
tmp = rewriter
.create<ROCDL::DPPUpdateOp>(loc, i32_ty, old, v1, 0x118, 0xF, 0xC,
true)
.getResult();
v1 = b.add(v1, tmp);

// v_add_f32 v1, v1, v1 row_bcast:15 row_mask:0xa
tmp = rewriter.create<ROCDL::DPPUpdateOp>(
loc, i32_ty, old, v1, 0x142, 0xA, 0xF, true).getResult();
tmp = rewriter
.create<ROCDL::DPPUpdateOp>(loc, i32_ty, old, v1, 0x142, 0xA, 0xF,
true)
.getResult();
v1 = b.add(v1, tmp);

// v_add_f32 v1, v1, v1 row_bcast:31 row_mask:0xc
tmp = rewriter.create<ROCDL::DPPUpdateOp>(
loc, i32_ty, old, v1, 0x143, 0xC, 0xF, true).getResult();
tmp = rewriter
.create<ROCDL::DPPUpdateOp>(loc, i32_ty, old, v1, 0x143, 0xC, 0xF,
true)
.getResult();
v1 = b.add(v1, tmp);

return v1;
Expand Down Expand Up @@ -1465,19 +1479,21 @@ struct AtomicRMWOpConversion
Value permuteOffset = genPrefixSum(rewriter, maskI32);
permuteOffset = b.mul(permuteOffset, maskI32);
int waveSize = 64;
permuteOffset = b.select(b.icmp_eq(permuteOffset, b.i32_val(0)), b.i32_val(waveSize), permuteOffset);
permuteOffset = b.select(b.icmp_eq(permuteOffset, b.i32_val(0)),
b.i32_val(waveSize), permuteOffset);
permuteOffset = b.sub(permuteOffset, b.i32_val(1));
permuteOffset = b.mul(permuteOffset, b.i32_val(4));
operand = genI32TiledOp(rewriter, genPermute, operand, permuteOffset);
Value castedAddr = b.ptrtoint(i64_ty, rmwPtr);
castedAddr = genI32TiledOp(rewriter, genPermute, castedAddr, permuteOffset);
castedAddr =
genI32TiledOp(rewriter, genPermute, castedAddr, permuteOffset);
rmwPtr = b.inttoptr(rmwPtr.getType(), castedAddr);

// update mask
Value maskFlag = targetInfo.ballot(rewriter, loc, i64_ty, rmwMask);
Value numActiveLanes =
b.trunc(i32_ty, generatePopcount64(rewriter, maskFlag));

Value laneID = b.urem(tid, b.i32_val(waveSize));
rmwMask = b.icmp_ult(laneID, numActiveLanes);
}
Expand Down Expand Up @@ -1614,10 +1630,10 @@ struct AtomicRMWOpConversion
rmwPtr = b.ptrtoint(i64_ty, rmwPtr);

auto *curBlock = rewriter.getInsertionBlock();
auto *leaderBlock = curBlock->splitBlock(rewriter.getInsertionPoint());
leaderBlock->addArgument(i64_ty, loc);
leaderBlock->addArgument(operandElemType, loc);
auto *beforeLoop = rewriter.createBlock(
auto *atomicBlock = curBlock->splitBlock(rewriter.getInsertionPoint());
atomicBlock->addArgument(i64_ty, loc);
atomicBlock->addArgument(operandElemType, loc);
auto *initLoop = rewriter.createBlock(
curBlock->getParent(), std::next(Region::iterator(curBlock)));

rewriter.setInsertionPointToEnd(curBlock);
Expand All @@ -1630,24 +1646,26 @@ struct AtomicRMWOpConversion
Value neighbourFlag = targetInfo.ballot(rewriter, loc, i64_ty, isNeighbour);
Value numNeighbours =
b.trunc(i32_ty, generatePopcount64(rewriter, neighbourFlag));
// heuristic, do optimization only if # of neighbours is less than 32,
// [TODO], will calculate actual # of difference addresses
Value skipOpt = b.icmp_ult(numNeighbours, b.i32_val(32));
// Heuristic: atomic_add is optimized only if the number of
// neighbouring addresses in a wave is less than 32.
// TODO: Calculate the actual number of different addresses
// in a wave.
Value optAtomic = b.icmp_ult(numNeighbours, b.i32_val(32));

rewriter.create<LLVM::CondBrOp>(loc, skipOpt, beforeLoop, leaderBlock,
rewriter.create<LLVM::CondBrOp>(loc, optAtomic, initLoop, atomicBlock,
ValueRange({rmwPtr, operand}));
rewriter.setInsertionPointToEnd(beforeLoop);
rewriter.setInsertionPointToEnd(initLoop);

auto *afterLoopBlock = beforeLoop->splitBlock(rewriter.getInsertionPoint());
auto *afterLoopBlock = initLoop->splitBlock(rewriter.getInsertionPoint());
afterLoopBlock->addArgument(i32_ty, loc); // idx
afterLoopBlock->addArgument(i32_ty, loc); // cnt
afterLoopBlock->addArgument(int_ty(1), loc); // isLeader

auto *loopBody = rewriter.createBlock(
beforeLoop->getParent(), std::next(Region::iterator(beforeLoop)));
initLoop->getParent(), std::next(Region::iterator(initLoop)));
loopBody->addArgument(i32_ty, loc);

rewriter.setInsertionPointToEnd(beforeLoop);
rewriter.setInsertionPointToEnd(initLoop);
rewriter.create<LLVM::BrOp>(loc, b.i32_val(0), loopBody);

// Greedy search of same addr within wavefront. Also collect auxiliary
Expand Down Expand Up @@ -1774,18 +1792,15 @@ struct AtomicRMWOpConversion
Value leaderCond = leaderRes;
Value defaultRes = b.undef(operandElemType);
rewriter.create<LLVM::CondBrOp>(
loc, leaderCond, leaderBlock,
loc, leaderCond, atomicBlock,
ValueRange({rmwPtr, afterRedBlock->getArgument(0)}), endBlock,
ValueRange({defaultRes}));
rewriter.setInsertionPointToEnd(leaderBlock);
rewriter.setInsertionPointToEnd(atomicBlock);
// Utilize global atomic only by leader threads
Value addr = leaderBlock->getArgument(0);
Value addr = atomicBlock->getArgument(0);
Value atomAddr = b.inttoptr(origPtrType, addr);
Value atom = rewriter
.create<LLVM::AtomicRMWOp>(loc, opKind, atomAddr,
leaderBlock->getArgument(1),
memOrdering, scope)
.getResult();
Value atom = rewriter.create<LLVM::AtomicRMWOp>(
loc, opKind, atomAddr, atomicBlock->getArgument(1), memOrdering, scope);
rewriter.create<LLVM::BrOp>(loc, atom, endBlock);
rewriter.setInsertionPointToStart(endBlock);

Expand Down

0 comments on commit 3428124

Please sign in to comment.