Optimize matrix multiplication performance using tma multicast #3689

Open
1 of 7 tasks
rdspring1 opened this issue Jan 9, 2025 · 15 comments

@rdspring1
Collaborator

rdspring1 commented Jan 9, 2025

Overview

Multicast operand tiles to CTAs in a cluster to improve memory bandwidth.

Details

  • The multicast variant of TMA load allows copying a tile in global memory to the shared memory of multiple CTAs of a cluster.
  • For matrix multiplication, operand A and B tiles are shared among CTAs in a cluster.
  • Multicasting improves L2 cache performance and reduces the number of TMA transactions.

TODOs

  • Add cluster functions to runtime helper files (see #3672: Add cuda wrapper for cluster ptx operations); a sketch of such helpers follows this list.
  • Add tma multicast functions to runtime helper files.
  • Maybe support uint64_t dtype for multicast_mask
  • Add kir::ClusterSync and kir::BlockRankInCluster expressions
  • Add multicast parallel types, e.g., MCIDx, MCIDy, MCIDz
  • Synchronize CTAs in cluster using mbarrier::arrive(mbarrier, cta_id)
  • Implement multicast logic
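
A minimal sketch of what the cluster helpers referenced above could look like, using the CUDA cooperative-groups cluster API (sm_90+); the actual nvFuser runtime helpers may instead wrap the raw PTX (barrier.cluster.arrive/wait, %cluster_ctarank):

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Candidate lowering for kir::ClusterSync: synchronize all CTAs in the cluster.
__device__ inline void cluster_sync() {
  cg::this_cluster().sync();
}

// Candidate lowering for kir::BlockRankInCluster: linear rank of this CTA
// within its cluster.
__device__ inline unsigned int block_rank_in_cluster() {
  return cg::this_cluster().block_rank();
}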

Multicast Logic

If cluster_dims is set in fusion managed data, then apply tma multicast to operands.

  • Split the operand tile dimension by the number of CTAs in the cluster dimension
  • Parallelize the new IterDomain with the cluster parallel type
  • Replace the standard tma load with the multicast load, using the appropriate multicast mask (see the mask sketch after this list)
  • Add cluster_sync in the insert_syncs pass.
  • Use the distributed mbarrier arrive, mbarrier::arrive(mbarrier, cta_id)
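
For illustration, a sketch of how the multicast mask could be computed when the CTAs participating in a multicast are contiguous in cluster rank (the helper below is hypothetical, not existing nvFuser code, and the contiguous-rank layout is an assumption):

// Bit i of the 16-bit ctaMask selects the CTA with cluster rank i as a destination.
__device__ inline uint16_t multicast_mask(unsigned rank_in_cluster,
                                          unsigned ctas_per_multicast) {
  // First rank of the contiguous group this CTA belongs to.
  unsigned group_start = rank_in_cluster - rank_in_cluster % ctas_per_multicast;
  // One bit per participating CTA.
  return (uint16_t)(((1u << ctas_per_multicast) - 1u) << group_start);
}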

Pseudo-Code - Non-Warp Specialization.

void tma_load_kernel(...) { 
  __shared__ T smem_data[...];
  __shared__ uint64_t mbarrier;
  if (elect_sync()) {
    initialize_mbarrier(mbarrier, /* arrival count */ 1);
  }
  __syncthreads();
  cluster_sync();
  fence.sync.aligned;
 
  if (elect_sync()) {
    // Parallelize IterDomain with cluster dim
    for idx in range(number_ctas_in_cluster_dim) {
        set_barrier_transaction_bytes(mbarrier, local_cta_transaction_bytes);
        tma_multicast(mbarrier, tma_descriptor, gmem_addr, smem_addr, multicast_mask);
    }
  }
  __syncthreads();
  wait(mbarrier, phase);
  cluster_sync();
}

Pseudo-Code - Warp Specialization.

void tma_load_ws_kernel(...) { 
  __shared__ T smem_data[...];
  __shared__ uint64_t empty_mbarrier;
  __shared__ uint64_t full_mbarrier;

  if (elect_sync()) {
    initialize_mbarrier(empty_mbarrier, /*arrival_count=*/number_of_tma_transactions);
    initialize_mbarrier(full_mbarrier, 
                        /*arrival_count=*/number_of_compute_threads * number_of_CTAs_in_multicast);
  }
  __syncthreads();
  cluster_sync();
  fence.sync.aligned;
 
  if (tma_warp_group) {
      if (elect_sync()) {
          wait(empty_mbarrier, phase);

          // Parallelize IterDomain with cluster dim
          for idx in range(number_ctas_in_cluster_dim) {
              set_barrier_transaction_bytes(full_mbarrier, local_cta_transaction_bytes);
              tma_multicast(full_mbarrier, tma_descriptor, gmem_addr, smem_addr, multicast_mask);
          }
      }
  }

  if (compute_warp_group) {
    // Signal that the buffers are initially empty so the tma warp group can load.
    for all CTAs in multicast {
        arrive(empty_mbarrier, cta_id);
    }

    wait(full_mbarrier, phase);

    // compute something

    // Signal that the buffer has been consumed and can be refilled.
    for all CTAs in multicast {
        arrive(empty_mbarrier, cta_id);
    }
  }
}

Cutlass Snippets

  • Each CTA participating in the multicast divides the original smem box:
  cute::array<uint32_t, 5> smem_box_shape  = {1,1,1,1,1};
  cute::array<uint32_t, 5> smem_box_stride = {1,1,1,1,1};
  // The smem box is simply given by the sizes of the modes in tma_gbasis
  for_each(make_seq<tma_dim>{}, [&](auto i) {
    smem_box_shape[i] *= size<i>(tma_gbasis);
  });
  // Finally, truncate the tma box by the num_multicast
  for (uint32_t i = tma_dim-1, multicast = num_multicast; multicast > 1; --i) {
    assert(smem_box_shape[i] % multicast == 0 || multicast % smem_box_shape[i] == 0);
    uint32_t new_mult = ceil_div(multicast, smem_box_shape[i]);
    smem_box_shape[i] = ceil_div(smem_box_shape[i], multicast);
    multicast = new_mult;
  }
  • Use the EVICT_FIRST L2 cache hint for operand A (M, K); multicast it along the cluster N dimension (see the cache-policy sketch below)
  • Use the EVICT_LAST L2 cache hint for operand B (N, K); multicast it along the cluster M dimension
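
A hedged sketch of producing such an L2 eviction-priority hint with PTX's createpolicy instruction; how the resulting 64-bit policy is passed to the TMA instruction's .L2::cache_hint operand should be checked against the PTX ISA:

// Create a 64-bit L2 cache policy that marks loaded lines as evict-first.
__device__ inline uint64_t l2_evict_first_policy() {
  uint64_t policy;
  asm volatile("createpolicy.fractional.L2::evict_first.b64 %0;" : "=l"(policy));
  return policy;
}
// An evict-last policy is created the same way with L2::evict_last.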

PTX Details

The optional modifier .multicast::cluster allows copying of data from global memory to shared memory of multiple CTAs in the cluster. Operand ctaMask specifies the destination CTAs in the cluster such that each bit position in the 16-bit ctaMask operand corresponds to the %ctaid of the destination CTA. The source data is multicast to the same CTA-relative offset as dstMem in the shared memory of each destination CTA. The mbarrier signal is also multicast to the same CTA-relative offset as mbar in the shared memory of the destination CTA.

  • Multicast is limited to the first 16 CTAs in the cluster.
  • The transaction byte count covers only the data loaded into the shared memory of the local CTA.
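
For reference, a sketch of a 2D multicast TMA load built on this instruction, modeled after CUTLASS's SM90_TMA_LOAD_MULTICAST (the optional L2 cache-hint qualifier is omitted; operand ordering should be double-checked against the PTX ISA):

__device__ inline void tma_load_multicast_2d(
    const void* tma_descriptor,  // CUtensorMap for the operand
    uint64_t* mbarrier,          // mbarrier in this CTA's shared memory
    void* smem_dst,              // destination tile in this CTA's shared memory
    int32_t crd0, int32_t crd1,  // tile coordinates in the tensor map
    uint16_t multicast_mask) {   // bit i selects cluster rank i as a destination
  uint64_t desc_addr = reinterpret_cast<uint64_t>(tma_descriptor);
  uint32_t mbar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbarrier));
  uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
  asm volatile(
      "cp.async.bulk.tensor.2d.shared::cluster.global"
      ".mbarrier::complete_tx::bytes.multicast::cluster"
      " [%0], [%1, {%4, %5}], [%2], %3;"
      :
      : "r"(smem_addr), "l"(desc_addr), "r"(mbar_addr), "h"(multicast_mask),
        "r"(crd0), "r"(crd1)
      : "memory");
}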

References:

  1. https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk
  2. https://research.colfax-intl.com/tutorial-hopper-tma/
  3. https://research.colfax-intl.com/cutlass-tutorial-persistent-kernels-and-stream-k/
@rdspring1 rdspring1 self-assigned this Jan 9, 2025
@jacobhinkle
Collaborator

Will the BIDx, BIDy, and BIDz parallel types be updated to reflect this? Meaning, if I have a cluster dim of (2, 1, 1) and I parallelize a tv as [ CIDx{2}, BIDx{4} ], will we launch a grid of size (4, 1, 1) or (8, 1, 1)? In order to keep the analogy to the thread and block dim parallel types, I think it should be 8.

@rdspring1
Collaborator Author

__cluster_dims__ is independent of grid and block size.

cluster dims does not need to be a managed parameter.
It could also become a launch parameter and be inferred.
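
For example, the cluster shape can be supplied at launch time through the CUDA extensible launch API instead of the __cluster_dims__ attribute (host-side sketch; the kernel name, grid/block shapes, and cluster shape below are illustrative):

cudaLaunchConfig_t config = {};
config.gridDim = grid_dim;    // grid/block come from nvFuser's launch params
config.blockDim = block_dim;

cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeClusterDimension;
attrs[0].val.clusterDim.x = 2;  // e.g., 2 CTAs per cluster along x
attrs[0].val.clusterDim.y = 1;
attrs[0].val.clusterDim.z = 1;
config.attrs = attrs;
config.numAttrs = 1;

cudaLaunchKernelEx(&config, tma_load_kernel /*, kernel arguments */);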

@jacobhinkle
Collaborator

__cluster_dims__ is independent of grid and block size.

cluster dims does not need to be a managed parameter.
It could also become a launch parameter and be inferred.

But in order to split an axis and map one of the IDs as CID and the other as BID, they need to be truly hierarchical, don't they?

@rdspring1
Collaborator Author

@zasdfgbnm request for comments

@jacobhinkle
Collaborator

jacobhinkle commented Jan 10, 2025

hierarchical

[image]

This is maybe best described graphically, as we've discussed previously in the context of using ParallelType Exprs. The picture above shows a 1D version of the dimensions involved (it could also include warp group and warp, or device/node rank, if we really wanted to...). I am wondering, since we already have BIDx and it's easy to translate at codegen, how will we (should we?) represent CTA rank in cluster? If we have both CIDx and BIDx, will we need to ensure that those are not both mapped within the same TensorView? All the splits must be divisible, which lets us infer grid parameters like CTAGridDimX if we only map CIDx and CTARankInClusterX, but they'll need to be consistent.

@rdspring1
Collaborator Author

rdspring1 commented Jan 10, 2025

I think there is a misunderstanding of how the cluster parallel type is to be used.

Yes, we need a hierarchy on the CTAs, but the parallel type is not applied here.

Cluster Hierarchy of C tiles

  • Given a set of C tiles for the grid: (BM, BN)
  • Split BM and BN by cluster factor ---> (RM, CM, RN, CN)
  • Reorder iterDomains so we have correct ordering of tiles in persistent GEMM ---> (RM, RN, CM, CN)
  • Merge all iterDomains together ---> (RM * RN * CM * CN)
  • Split by number of SMs ---> (RT, SMs)
  • Apply block parallelization ---> (RT, SMs (BIDx))

Cluster Parallel Type is not applied in the ordering of the tiles.
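
Roughly, in nvFuser's TensorView scheduling API, the tile ordering above could be expressed as follows (a sketch; axis positions and extents are illustrative):

// C tiles as the outer iterDomains: [BM, BN, ...]
tv->split(0, CM);               // [RM, CM, BN, ...]
tv->split(2, CN);               // [RM, CM, RN, CN, ...]
tv->reorder({{1, 2}, {2, 1}});  // [RM, RN, CM, CN, ...]
tv->merge(0);                   // [RM*RN, CM, CN, ...]
tv->merge(0);                   // [RM*RN*CM, CN, ...]
tv->merge(0);                   // [RM*RN*CM*CN, ...]
tv->split(0, number_of_sms);    // [RT, SMs, ...]
tv->axis(1)->parallelize(ParallelType::BIDx);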

How is the Cluster Parallel Type to be used?

TMA multicast requires each CTA to load a slice of the overall tile. We divide the TMA tile by the number of CTAs participating in the multicast. In NvFuser, we can split a tile dimension by the number of CTAs in the cluster dimension. This creates a for-loop that launches multiple tma load operations; that for-loop is parallelized among the CTAs participating in the multicast, and this is where the cluster parallel type is applied.

  • Given operand A, which is shared along rows of the cluster (CTA-M, CTA-K)
  • Split CTA-M by the multicast factor ---> (RCM, some-number-of-ctas-along-cluster-x-dim, CTA-K)
  • Apply cluster parallelization ---> (RCM, ParallelType::CIDx, CTA-K)

Technically, the multicast factor (some-number-of-ctas-along-cluster-x-dim) is bounded by both 16 and cluster-x-dim.

The Cluster Parallel Type parallelizes the multicast operation across the CTAs in the cluster.
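
As a sketch, using the proposed parallel type name (ParallelType::CIDx does not exist in nvFuser today), the operand A load might be scheduled as:

// Operand A tile to be loaded into shared memory: [CTA-M, CTA-K]
// multicast_factor = number of CTAs along cluster x that share this tile
tv_a->split(0, multicast_factor);                // [RCM, multicast_factor, CTA-K]
tv_a->axis(1)->parallelize(ParallelType::CIDx);  // proposed cluster/multicast type
// The remaining tile axes stay with the TMA load (ParallelType::Bulk).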

@rdspring1
Collaborator Author

I guess you can rename it to the multicast parallel type, which is closely tied to the cluster dimension.

For Hopper, the limitation is that some-number-of-ctas-along-cluster-x-dim is bounded by both 16 and cluster-x-dim.

@zasdfgbnm
Collaborator

  • Reorder iterDomains so we have correct ordering of tiles in persistent GEMM ---> (RM, RN, CM, CN)
  • Merge all iterDomains together ---> (RM * RN * CM * CN)
  • Split by number of SMs ---> (RT, SMs)

Does SMs always need to be a multiple of CM * CN? For example, assume SMs is 127 and CM = CN = 2. In this case, every 4 tiles form a group. However, because 127 is not a multiple of 4, we will have a "stream-k"-like pattern, like below:

[image]

That is, (RT=0, SM={124, 125, 126}) and (RT=1, SM=0) belong to the same group. Because they have different RT, it is impossible to multicast them.

@zasdfgbnm
Collaborator

Intuitively, multicasting sounds to me like it is naturally represented as a broadcasting ID parallelized on "CTA rank in cluster". Not sure if this is aligned with what is proposed here.

@zasdfgbnm
Collaborator

@rdspring1 Does this picture describe how you are thinking ParallelType::Multicast should be used?

[image]

@rdspring1
Collaborator Author

Note: I added warp specialization pseudo-code, which requires supporting distributed mbarrier arrive.

Details

  • To prevent race conditions with tma multicast, we use clusterSync. This causes poor performance for warp specialization because it synchronizes all warp groups.
  • For warp specialization, we use a distributed mbarrier arrive, which allows CTAs in the cluster to signal arrival at mbarriers on different SMs (see the sketch after this list).
  • For our current mbarrier circular buffering implementation, the empty_mbarrier must wait for all compute threads of all CTAs participating in the multicast.
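
A sketch of that distributed arrive, modeled after CUTLASS's ClusterBarrier::arrive (mapa translates the local shared-memory address of the mbarrier to the same offset in the destination CTA; exact qualifiers should be checked against the PTX ISA):

__device__ inline void mbarrier_arrive_remote(uint64_t* mbarrier, uint32_t cta_id) {
  uint32_t mbar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbarrier));
  asm volatile(
      "{\n"
      ".reg .b32 remote_addr;\n"
      // Address of the mbarrier at the same offset in the destination CTA.
      "mapa.shared::cluster.u32 remote_addr, %0, %1;\n"
      "mbarrier.arrive.shared::cluster.b64 _, [remote_addr];\n"
      "}\n"
      :
      : "r"(mbar_addr), "r"(cta_id)
      : "memory");
}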

@rdspring1
Collaborator Author

rdspring1 commented Jan 14, 2025

Does this picture describe how you are thinking ParallelType::Multicast should be used?

The picture matches how I think the CTA tiles should be traversed w.r.t. __cluster_dims__, but ParallelType::Multicast isn't needed here.

Both the CTA traversal and the tma multicast are bounded by __cluster_dims__, but they do not have to use the same factor.

ParallelType::Multicast is only required for the iterDomains in the tma multicast itself, e.g., the CTAs working together to load sub-tiles of the original TMA load.

Note: The number of CTAs participating in the multicast is less than the maximum number of CTAs in the cluster for hopper.

@rdspring1
Collaborator Author

Does SMs always need to be a multiple of CM * CN?

It is a concern.

These are two half-baked ideas.

  1. I think the multicast per mma operand may need to be predicated if CM or CN is not evenly divisible by cluster dims.
  2. I wonder if you can pad the CM * CN tiles so you can always multicast.

@zasdfgbnm
Collaborator

I am trying to derive from first principles how multicasting should be done.

Consider operand A, whose N axis is broadcasting. The schedule of A would be something like below:

[image]

In the above figure, the smem tv's logical domain is [M, K, bN], and the loop domain is circled in red. For simplicity, further transformations on Mi, Ki, and bNi are omitted. Let's just assume Mi and Ki are parallelized on ParallelType::Bulk. Let's use i_rt, i_mi, i_bk, cidx, i_ki, rankx, and i_ni to represent loop variables, and use i_m, i_k, and i_n to represent the index of the logical domain. Clearly, i_m, i_k, and i_n are functions of i_rt, i_mi, i_bk, cidx, i_ki, rankx, and i_ni. Let's use notation like below to represent this function relationship:

i_m(rankx; i_rt, i_mi, i_bk, cidx, i_ki, i_ni)

The code we want to generate for A is something like below:

if (this_cta_should_issue_tma() && elect_sync()) {
  tma_multicast(address, mask());
}

and the question is, what is this this_cta_should_issue_tma and what is this mask?

Before answering this, let's define same(rankx1, rankx2; i_rt, i_mi, i_bk, cidx, i_ki, i_ni) as:

same(rankx1, rankx2; i_rt, i_mi, i_bk, cidx, i_ki, i_ni) :=
  i_m(rankx1; i_rt, i_mi, i_bk, cidx, i_ki, i_ni) ==  i_m(rankx2; i_rt, i_mi, i_bk, cidx, i_ki, i_ni) &&
  i_k(rankx1; i_rt, i_mi, i_bk, cidx, i_ki, i_ni) == i_k(rankx2; i_rt, i_mi, i_bk, cidx, i_ki, i_ni)

That is, for the given i_rt, i_mi, i_bk, cidx, i_ki, and i_ni, the specified two CTAs rankx1 and rankx2 in the same CGA are loading the same logical item. Note that in the above definition of same, there is no requirement on the equality of i_n, because it is broadcasting, so different index values refer to the same logical item.

First, we note that, because i_k does not depend on rankx, the equality condition for i_k in the definition of same is always true. Therefore, we have

same(rankx1, rankx2; i_rt, i_mi, i_bk, cidx, i_ki, i_ni) :=
  i_m(rankx1; i_rt, i_mi, i_bk, cidx, i_ki, i_ni) == i_m(rankx2; i_rt, i_mi, i_bk, cidx, i_ki, i_ni)

Because i_m does not depend on i_bk, i_ki, or i_ni, we can omit these parameters in the definition of same and write it as:

same(rankx1, rankx2; i_rt, i_mi, cidx)

Because Mi is parallelized on Bulk, there is actually no loop generated for it, so we define same_bulk as:

same_bulk(rankx1, rankx2; i_rt, cidx) :=
  for all i_mi, we have same(rankx1, rankx2; i_rt, i_mi, cidx)

That is, same_bulk represents that, for the given i_rt and cidx, the specified two CTAs rankx1 and rankx2 in the same CGA are loading the same logical bulk. Let i_bm denote the index of BM; then we have i_m = i_bm * Mi + i_mi, and same_bulk can be simplified as:

same_bulk(rankx1, rankx2; i_rt, cidx) := i_bm(rankx1; i_rt, cidx) == i_bm(rankx2; i_rt, cidx)

With this same_bulk, we can easily generate the predicate this_cta_should_issue_tma as:

this_cta_should_issue_tma(rank; i_rt, cidx) :=
  for all r < rank, we have !same_bulk(r, rank; i_rt, cidx)

To make it easier to generate code, we can write it in the equivalent form below:

this_cta_should_issue_tma(rank; i_rt, cidx) :=
  for all r in [0, CGA size), we have (r >= rank || !same_bulk(r, rank; i_rt, cidx))

We can also generate mask using same_bulk as:

mask(rank; i_rt, cidx) :=
  (same_bulk(0, rank; i_rt, cidx) ? 0b1 : 0) | (same_bulk(1, rank; i_rt, cidx) ? 0b10 : 0) | ...

Note that, because the above this_cta_should_issue_tma and mask are derived from first principles, they are always correct, regardless of the divisibility of CM*CN and SMs. And because i_rt and cidx are parameters, different iterations and different clusters could be using different this_cta_should_issue_tma and mask values.

Depending on the divisibility of CM*CN and SMs, this_cta_should_issue_tma and mask may or may not be further mathematically simplifiable. For the case where further simplification is not possible, we are done: there is no "better" way than just generating this_cta_should_issue_tma and mask as above. For that case, I have no idea whether the performance will be satisfactory, and my guess is no. So I believe we should focus on cases where divisibility can be used to simplify this_cta_should_issue_tma and mask.
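
To make this concrete, the generated code would look something like the following sketch, where same_bulk is whatever predicate the compiler derives for the given schedule and rank, i_rt, cidx, and the CGA size are as defined above:

// Issue the TMA only from the lowest-ranked CTA that loads this bulk.
__device__ bool this_cta_should_issue_tma(uint32_t rank, int64_t i_rt, uint32_t cidx,
                                          uint32_t cga_size) {
  for (uint32_t r = 0; r < cga_size; ++r) {
    if (r < rank && same_bulk(r, rank, i_rt, cidx)) {
      return false;
    }
  }
  return true;
}

// One bit per CTA in the cluster that loads the same bulk.
__device__ uint16_t multicast_mask(uint32_t rank, int64_t i_rt, uint32_t cidx,
                                   uint32_t cga_size) {
  uint16_t mask = 0;
  for (uint32_t r = 0; r < cga_size; ++r) {
    if (same_bulk(r, rank, i_rt, cidx)) {
      mask |= (uint16_t)(1u << r);
    }
  }
  return mask;
}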

(to be continued)

@zasdfgbnm
Collaborator

zasdfgbnm commented Jan 15, 2025

(continuing)

It is important to note that we can implement same_bulk as a "narrowing approximation". That is, it may consider some equal things unequal, but it will never consider unequal things equal. For example, we can implement same_bulk as

same_bulk(...) := false

In this case, we will lose optimization opportunity by not multicasting at all, but we will never generate wrong code.

Let:

I(rankx; i_rt, cidx) = i_rt * SMs + cidx * cluster_size + rankx

then we have:

i_bm(rankx; i_rt, cidx) = I/(CM*CN)/RN * CM + I%(CM*CN)/CN

From Theorem 2.12 in https://github.com/NVIDIA/Fuser/blob/main/doc/math/integer-division.md, we have:

I%(CM*CN) = I%CN + (I/CN)%CM*CN

Therefore, from Theorem 2.15.1 in https://github.com/NVIDIA/Fuser/blob/main/doc/math/integer-division.md, we have:

I%(CM*CN)/CN = (I%CN + (I/CN)%CM*CN)/CN = (I/CN)%CM

Also, from Theorem 2.11 in https://github.com/NVIDIA/Fuser/blob/main/doc/math/integer-division.md, we have:

 I/(CM*CN) = I/CN/CM

So, we can simplify i_bm as:

i_bm(rankx; i_rt, cidx) = I/CN/CM/RN * CM + I/CN%CM

Then

same_bulk(rankx1, rankx2; i_rt, cidx) :=
  i_bm(rankx1; i_rt, cidx) == i_bm(rankx2; i_rt, cidx)

=> (mathematical equivalence)
  I(rankx1; i_rt, cidx)/CN/CM/RN == I(rankx2; i_rt, cidx)/CN/CM/RN &&
  I(rankx1; i_rt, cidx)/CN%CM == I(rankx2; i_rt, cidx)/CN%CM

=> (narrowing approximate)
  I(rankx1; i_rt, cidx)/CN/CM == I(rankx2; i_rt, cidx)/CN/CM &&
  I(rankx1; i_rt, cidx)/CN%CM == I(rankx2; i_rt, cidx)/CN%CM

=> (mathematical equivalence)
  I(rankx1; i_rt, cidx)/CN == I(rankx2; i_rt, cidx)/CN

in short, we have:

same_bulk(rankx1, rankx2; i_rt, cidx) ~
  (i_rt * SMs + cidx * cluster_size + rankx1) / CN == (i_rt * SMs + cidx * cluster_size + rankx2) / CN

from the above equation, we see that:

  • If SMs is a multiple of CN, then same_bulk does not depend on i_rt due to Theorem 2.15.1.
  • If cluster_size is a multiple of CN, then same_bulk does not depend on cidx due to Theorem 2.15.1.

When both conditions above are satisfied (I believe this is the most common case in practice), we have

same_bulk(rankx1, rankx2; i_rt, cidx) ~ (rankx1 / CN == rankx2 / CN)

For this case, this_cta_should_issue_tma should be as simple as rank % CN == 0, and the mask used by the issuing CTA should be (2 ** CN - 1) << rank.
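
In code, the fully simplified case is just the following (rank is the CTA's rank along the cluster x dimension; only the issuing CTA, with rank % CN == 0, uses the mask):

bool this_cta_should_issue_tma = (rank % CN) == 0;
// Bits [rank, rank + CN): the CN consecutive CTAs that share the same bulk.
uint16_t mask = (uint16_t)(((1u << CN) - 1u) << rank);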
