Optimize matrix multiplication performance using tma multicast #3689
Will the BIDx, BIDy, and BIDz parallel types be updated to reflect this? Meaning, if I have a cluster dim of (2, 1, 1) and I parallelize a tv as [CIDx{2}, BIDx{4}], will we launch a grid of size (4, 1, 1) or (8, 1, 1)? In order to keep the analogy to the thread and block dim parallel types, I think it should be 8.
Cluster dims do not need to be a managed parameter.
But in order to split an axis and map one of the IDs as CID and the other as BID, they need to be truly hierarchical, don't they?
@zasdfgbnm request for comments
This is maybe best described graphically, as we've discussed previously about using ParallelType Exprs. The picture above shows a 1D version of the dimensions involved (we could also include warp group and warp, or device/node rank, if we really wanted to...). I am wondering: since we already have BIDx and it's easy to translate at codegen, how will we (should we?) represent CTA rank in cluster? If we have both CIDx and BIDx, will we need to ensure that those are not both mapped within the same TensorView? All the splits must be divisible, which lets us infer grid parameters like CTAGridDimX if we only map CIDx and CTARankInClusterX, but they'll need to be consistent.
I think there is a misunderstanding of how the cluster parallel type is to be used. Yes, we need a hierarchy on the CTAs, but the parallel type is not applied here.

Cluster Hierarchy of C tiles

The Cluster Parallel Type is not applied in the ordering of the tiles.

How is the Cluster Parallel Type to be used?

TMA multicast requires each CTA to load a slice of the overall tile. We divide the TMA tile by the number of CTAs participating in the multicast. In NvFuser, we can split a tile dimension by the number of CTAs in the cluster dimension. This will create a for-loop in the generated code launching multiple TMA load operations. This for-loop is parallelized among the CTAs participating in the multicast, and this is where the cluster parallel type is applied. Technically, the multicast factor is the number of CTAs participating in the multicast; the Cluster Parallel Type parallelizes the multicast operation across the SMs in the cluster.
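To make this concrete, here is a rough sketch of what such a schedule could look like in NvFuser. ParallelType::MCIDx is only the name proposed in this issue's TODO list, and the tile sizes are made up:

```cpp
// smem_a: the shared-memory TensorView for operand A, tiled as
// [M_tile, K_tile], with CN CTAs participating in the multicast.
smem_a->split(0, CN, /*inner_split=*/false);
// Loop domain is now [CN, M_tile/CN, K_tile]: iterating axis 0 launches one
// TMA load per slice of the tile.
smem_a->axis(0)->parallelize(ParallelType::MCIDx);
// Parallelizing that loop distributes the TMA loads across the CTAs in the
// cluster; each load is multicast to all participating CTAs.
```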
I guess you can rename it to a multicast parallel type, which is highly correlated with the cluster dimension. For Hopper, the limitation is:
Intuitively, multicasting sounds to me like something naturally represented as a broadcasting ID parallelized on "CTA rank in cluster". Not sure if this is aligned with what is proposed here.
@rdspring1 Does this picture describe how you are thinking about it?
Note: I added warp specialization pseudo-code, which requires supporting distributed mbarrier arrive.
The picture matches how I think the CTA tiles should be traversed. Both the CTA traversal and the TMA multicast are bounded by the cluster dimensions.

Note: The number of CTAs participating in the multicast is less than the maximum number of CTAs in the cluster for Hopper.
It is a concern. These are two half-baked ideas.
I am trying to derive from first principles how multicasting should be done. Consider operand A, whose shape is (M, K). In the above figure, the smem tv's logical domain is scheduled into the loop structure (i_rt, i_mi, i_bk, cidx, i_ki, i_ni) used below.
The code we want to generate for A is something like below:

```cpp
if (this_cta_should_issue_tma() && elect_sync()) {
  tma_multicast(address, mask());
}
```

and the question is: what are this_cta_should_issue_tma() and mask()?

Before answering this, let's define:

```
same(rankx1, rankx2; i_rt, i_mi, i_bk, cidx, i_ki, i_ni) :=
  i_m(rankx1; i_rt, i_mi, i_bk, cidx, i_ki, i_ni) == i_m(rankx2; i_rt, i_mi, i_bk, cidx, i_ki, i_ni) &&
  i_k(rankx1; i_rt, i_mi, i_bk, cidx, i_ki, i_ni) == i_k(rankx2; i_rt, i_mi, i_bk, cidx, i_ki, i_ni)
```

That is, for the given loop indices, same tells whether the CTAs with ranks rankx1 and rankx2 load the same element.

First, we note that, because i_k does not depend on the CTA rank, this simplifies to:

```
same(rankx1, rankx2; i_rt, i_mi, i_bk, cidx, i_ki, i_ni) :=
  i_m(rankx1; i_rt, i_mi, i_bk, cidx, i_ki, i_ni) == i_m(rankx2; i_rt, i_mi, i_bk, cidx, i_ki, i_ni)
```

Note that same then depends only on (rankx1, rankx2; i_rt, i_mi, cidx), so we can write it as same(rankx1, rankx2; i_rt, i_mi, cidx).

Because a TMA operation loads the whole bulk tile at once, covering all values of i_mi, we define:

```
same_bulk(rankx1, rankx2; i_rt, cidx) :=
  for all i_mi, we have same(rankx1, rankx2; i_rt, i_mi, cidx)
```

That is:

```
same_bulk(rankx1, rankx2; i_rt, cidx) := i_bm(rankx1; i_rt, cidx) == i_bm(rankx2; i_rt, cidx)
```

With this definition:

```
this_cta_should_issue_tma(rank; i_rt, cidx) :=
  for all r < rank, we have !same_bulk(r, rank; i_rt, cidx)
```

To make it easier to generate code, we can write it into its equivalent form as below:

```
this_cta_should_issue_tma(rank; i_rt, cidx) :=
  for all r in CGA size, we have (r >= rank || !same_bulk(r, rank; i_rt, cidx))
```

We can also generate the mask:

```
mask(rank; i_rt, cidx) :=
  (same_bulk(0, rank; i_rt, cidx) ? 0b1 : 0) | (same_bulk(1, rank; i_rt, cidx) ? 0b10 : 0) | ...
```

Note that, because the above definitions depend on i_rt and cidx, in general they must be evaluated at runtime. Depending on the divisibility of the splits in the schedule, they may simplify to compile-time constants.

(to be continued)
(continuing) It is important to note that we can always implement

```
same_bulk(...) := false
```

In this case, we will lose optimization opportunity by not multicasting at all, but we will never generate wrong code.

Let:

```
I(rankx; i_rt, cidx) = i_rt * SMs + cidx * cluster_size + rankx
```

then we have:

```
i_bm(rankx; i_rt, cidx) = I / (CM*CN) / RN * CM + I % (CM*CN) / CN
```

Applying Theorem 2.12, Theorem 2.15.1, and Theorem 2.11 from https://github.com/NVIDIA/Fuser/blob/main/doc/math/integer-division.md to simplify the division and modulo terms (with cluster_size = CM * CN), we see that:

- if SMs is divisible by cluster_size, then I % (CM*CN) = rankx, and I / (CM*CN) does not depend on rankx;
- if, in addition, the remaining splits in the schedule are divisible, there are no partial tiles or waves to special-case.

When both conditions above are satisfied (I believe this is the most common in practice), we have:

```
i_bm(rankx; i_rt, cidx) = (i_rt * SMs + cidx * cluster_size) / (CM*CN) / RN * CM + rankx / CN
```

For this case, same_bulk reduces to:

```
same_bulk(rankx1, rankx2) := rankx1 / CN == rankx2 / CN
```

which no longer depends on i_rt or cidx. The predicate and the mask then become compile-time constants: the CTA with rankx % CN == 0 issues the TMA, and the mask covers the CN ranks that share the same value of rankx / CN.
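To make the general case concrete, here is a sketch in CUDA of how the two predicates could be computed at runtime. All names and parameters are illustrative, and same_bulk is specialized to the fully-divisible case derived above:

```cpp
#include <cstdint>

// Illustrative parameters: CN CTAs multicast together, cga_size CTAs per
// cluster. In general same_bulk may also depend on i_rt and cidx; here it is
// specialized to the fully-divisible case, where it depends only on the ranks.
constexpr int CN = 2;
constexpr int cga_size = 8;

__device__ bool same_bulk(int r1, int r2) {
  return r1 / CN == r2 / CN;  // CTAs in the same group load the same tile
}

__device__ bool this_cta_should_issue_tma(int rank) {
  // Equivalent form from above: the lowest-ranked CTA of each
  // same_bulk-equivalence class issues the TMA load.
  for (int r = 0; r < cga_size; ++r) {
    if (r < rank && same_bulk(r, rank)) {
      return false;
    }
  }
  return true;
}

__device__ uint16_t multicast_mask(int rank) {
  uint16_t mask = 0;
  for (int r = 0; r < cga_size; ++r) {
    if (same_bulk(r, rank)) {
      mask |= uint16_t(1) << r;  // one bit per CTA that receives the tile
    }
  }
  return mask;
}
```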
Overview
Multicast operand tiles to CTAs in a cluster to improve memory bandwidth.
Details
TODOs

- uint64_t dtype for multicast_mask
- kir::ClusterSync and kir::BlockRankInCluster expressions
- MCIDx, MCIDy, MCIDz parallel types
- mbarrier::arrive(mbarrier, cta_id) for distributed mbarrier arrive (see the sketch below)
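For the distributed arrive item, here is a sketch of what it could lower to, mirroring the mapa + cluster-scoped mbarrier.arrive pattern used by CUTLASS's cluster barrier (the wrapper name and signature are illustrative; the PTX instructions are standard):

```cpp
#include <cstdint>

// Arrive at an mbarrier that lives in another CTA's shared memory within the
// same cluster. mapa translates our local smem address of the mbarrier into
// the target CTA's shared::cluster address space; the cluster-scoped
// mbarrier.arrive then performs the remote arrive.
__device__ void mbarrier_arrive(uint64_t* mbarrier, uint32_t cta_id) {
  uint32_t local = static_cast<uint32_t>(__cvta_generic_to_shared(mbarrier));
  uint32_t remote;
  asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n"
               : "=r"(remote)
               : "r"(local), "r"(cta_id));
  asm volatile("mbarrier.arrive.release.cluster.shared::cluster.b64 _, [%0];\n"
               :
               : "r"(remote)
               : "memory");
}
```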
Multicast Logic
If cluster_dims is set in fusion managed data, then apply TMA multicast to the operands.

- Insert cluster_sync in the insert_syncs pass (see the sketch after this list).
- Use mbarrier::arrive(mbarrier, cta_id) to signal the CTAs participating in the multicast.
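For kir::ClusterSync, a minimal sketch of a plausible lowering using the standard CUDA cluster API:

```cpp
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// A full-cluster barrier (barrier.cluster.arrive + barrier.cluster.wait):
// every CTA in the cluster must reach this point before any smem buffer that
// was filled by multicast can safely be overwritten.
__device__ void cluster_sync() {
  cg::this_cluster().sync();
}
```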
Pseudo-Code - Non-Warp Specialization.
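A sketch of the structure this section describes, using the helper names from the discussion above (all of them placeholders rather than actual generated code):

```cpp
// Non-warp-specialized main loop with TMA multicast (pseudo-code sketch).
for (int i_bk = 0; i_bk < num_k_tiles; ++i_bk) {
  if (this_cta_should_issue_tma(cta_rank_in_cluster) && elect_sync()) {
    // One CTA per equivalence class issues the load; the hardware writes the
    // tile into the smem of every CTA whose bit is set in the mask.
    tma_multicast(operand_tile(i_bk), multicast_mask(cta_rank_in_cluster));
  }
  mbarrier_wait(full_barrier, parity);  // wait for the multicast to land
  mma(smem_a, smem_b, accum);           // compute on the shared tile
  cluster_sync();  // all CTAs in the cluster must finish reading the
                   // multicast tile before the next iteration overwrites it
}
```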
Pseudo-Code - Warp Specialization.
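Likewise, a sketch of the warp-specialized variant; the key difference is that consumers release a pipeline stage on every receiving CTA via the distributed mbarrier::arrive(mbarrier, cta_id):

```cpp
// Warp-specialized main loop with TMA multicast (pseudo-code sketch).
if (is_producer_warp()) {
  for (int i_bk = 0; i_bk < num_k_tiles; ++i_bk) {
    mbarrier_wait(empty_barrier[stage(i_bk)], parity);  // stage is free
    if (this_cta_should_issue_tma(cta_rank_in_cluster) && elect_sync()) {
      tma_multicast(operand_tile(i_bk), multicast_mask(cta_rank_in_cluster));
    }
  }
} else {  // consumer warp groups
  for (int i_bk = 0; i_bk < num_k_tiles; ++i_bk) {
    mbarrier_wait(full_barrier[stage(i_bk)], parity);  // data has landed
    mma(smem_a, smem_b, accum);
    // Release this stage on every CTA that received the multicast, using the
    // distributed mbarrier arrive from the TODO list.
    for (uint32_t cta_id : multicast_group) {
      mbarrier_arrive(empty_barrier[stage(i_bk)], cta_id);
    }
  }
}
```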
Cutlass Snippets
- EVICT_FIRST: L2 cache hint for operand A; multicast along the N dimension for operand A (M, K).
- EVICT_LAST: L2 cache hint for operand B; multicast along the M dimension for operand B (N, K).

PTX Details
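The core instruction is the multicast form of the bulk tensor copy, as used by CUTLASS's SM90_TMA_LOAD_MULTICAST. A sketch for a 2D tile (the wrapper name and parameters are illustrative; the PTX is standard):

```cpp
#include <cstdint>

// Multicast TMA load of a 2D tile: every CTA whose bit is set in `mask`
// receives the tile at the same smem destination. The EVICT_FIRST/EVICT_LAST
// L2 hints go through the 64-bit cache-policy operand of the corresponding
// .L2::cache_hint variant; the plain multicast form is shown here.
__device__ void tma_multicast_2d(
    const void* tma_descriptor,  // pointer to the CUtensorMap
    uint32_t smem_dst,           // destination address in shared::cluster
    uint32_t smem_mbar,          // mbarrier tracking transaction completion
    uint16_t mask,               // one bit per receiving CTA in the cluster
    int32_t crd0, int32_t crd1) {
  asm volatile(
      "cp.async.bulk.tensor.2d.shared::cluster.global"
      ".mbarrier::complete_tx::bytes.multicast::cluster"
      " [%0], [%1, {%2, %3}], [%4], %5;\n"
      :
      : "r"(smem_dst), "l"(tma_descriptor), "r"(crd0), "r"(crd1),
        "r"(smem_mbar), "h"(mask)
      : "memory");
}
```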