
[QST] Cute docs need a concrete example using tensor cores #2063

Closed · capybara-club opened this issue Jan 25, 2025 · 6 comments

@capybara-club

I'm trying to learn how to use CuTe, and it's surprising that even the sgemm_sm80.cu example defaults to UniversalFMA and gives much lower GFLOPs than the hardware's MMA can reach. Could you update a version of this file that even just has commented-out TiledMMA versions for different tensor ops? I feel like that would be worth more than many, many lines of documentation. From:


  TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TA>{},
                                    Layout<Shape<_32,_8>>{}, // Thr layout 32x8 m-major
                                    Layout<Shape< _4,_1>>{});// Val layout  4x1 m-major
  TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TB>{},
                                    Layout<Shape<_32,_8>>{}, // Thr layout 32x8 n-major
                                    Layout<Shape< _4,_1>>{});// Val layout  4x1 n-major

  TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
                                 Layout<Shape<_16,_16,_1>>{});  // 16x16x1 TiledMMA 

Having additional mmaC variants there commented out would be fantastic, especially if they lined up with the copy tiling.

In the meantime, could somebody help me with the edits needed to make sgemm_sm80.cu use the TF32 tensor cores on SM80?
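
For concreteness, this is the kind of drop-in I'm imagining (a minimal, untested sketch using the SM80_16x8x8_F32TF32TF32F32_TN atom from cute/arch/mma_sm80.hpp; TA/TB would need to become cute::tfloat32_t):

  // Untested sketch: swap the UniversalFMA atom for the SM80 TF32 tensor-core atom.
  // TA/TB become cute::tfloat32_t and TC stays float. Note this TiledMMA uses
  // 2x2x1 warps = 128 threads, so the copy tilings above would have to match,
  // and the smem layouts likely want a swizzle/ldmatrix treatment for speed.
  TiledMMA mmaC = make_tiled_mma(SM80_16x8x8_F32TF32TF32F32_TN{},
                                 Layout<Shape<_2,_2,_1>>{});  // 2x2x1 TiledMMA of warps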

@ccecka commented Jan 29, 2025

Agreed. I have an open MR updating a lot of these CuTe examples, but it's been delayed for various reasons.

In the meantime, here's the current state of sgemm_sm80.cu with Tensor Cores from that update MR:
sgemm_sm80.txt

@capybara-club (Author)

Thank you! Happy to hear it's in the works. Being able to look at different data types and transposes and see how the tiling and swizzle change will be really helpful, especially if the versions get close to the hardware limit. Following the code from the CUTLASS profiler down to what's eventually happening in CuTe is pretty hard to track.

Thanks for your example! What should I replace the jetfire:: calls with in the interim?

@ccecka commented Jan 30, 2025

Ah yes, I've updated the code above to remove those for you, thanks.

@capybara-club (Author)

Thank you, this is so helpful!

It gets really close to the 16x8x8 CUTLASS one (0.6 ms vs. 0.56 ms) for m=n=k=4096. Again, really helpful.

I don't want to wear out my welcome, but the CUTLASS version has a CTA size of (256, 128, 32), whereas this one is (128, 128, 64).

  Arguments: --gemm_kind=universal --m=4096 --n=4096 --k=4096 --A=f16:row --B=f16:column --C=f16:column --D=f16:column \
             --alpha=1 --beta=0 --split_k_mode=serial --split_k_slices=1 --batch_count=1 --raster_order=heuristic \
             --runtime_input_datatype_a=invalid --runtime_input_datatype_b=invalid --use_pdl=false --swizzle_size=1 \
             --op_class=tensorop --accum=f16 --cta_m=256 --cta_n=128 --cta_k=32 --cluster_m=1 --cluster_n=1 --cluster_k=1 \
             --cluster_m_fallback=0 --cluster_n_fallback=0 --cluster_k_fallback=0 --stages=2 --warps_m=4 --warps_n=2 \
             --warps_k=1 --inst_m=16 --inst_n=8 --inst_k=8 --min_cc=75 --max_cc=1024

To help me understand something, would you mind showing me how the code changes if I drive the CTA size m up from 128 to 256 and the CTA size k down from 64 to 32 in this example? I'm assuming they're just simple number changes in the gemm_tn function; if it's more involved than that, please don't bother.

@ccecka commented Jan 30, 2025

Then you get into some layout engineering; here's another thread where we walk through some of that:
#1953

An easy way to get 256x128x32 is to cut down the vectorization and change the layouts:

  // Define CTA tile sizes (static)
  auto bM = Int<256>{};
  auto bN = Int<128>{};
  auto bK = Int< 32>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
  auto bP = Int<3>{};  // Pipeline

  // Define the smem layouts (static)
  // Swizzles for LDSM and 64b k-major loads
  auto swizzle_atom = composition(Swizzle<3,3,3>{},
                                  Layout<Shape <_16,Shape <_8,  _4>>,
                                         Stride< _8,Stride<_1,_128>>>{});

  auto sA = tile_to_shape(swizzle_atom, make_shape(bM,bK,bP));

  // Define the thread layouts (static)
  TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint64_t>, cute::half_t>{},
                                    Layout<Shape<_16,_8>,Stride<_8,_1>>{},  // Thr layout 16x8 k-major
                                    Layout<Shape< _1,_4>>{});               // Val layout  1x4 k-major

Untested -- it might work with that swizzle, maybe a different swizzle, etc. You'll want to balance GMEM vectorization and cache-line utilization against SMEM bank conflicts, and optimize the data layout sA for the read and write stages and the copyA operation for access patterns.
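
A quick way to study variants like this is to print the layouts (a minimal host-side sketch, assuming CuTe's print utilities accept swizzled composed layouts):

  // Untested sketch: dump the swizzled smem atom above for inspection.
  #include <cute/tensor.hpp>
  using namespace cute;

  int main() {
    auto swizzle_atom = composition(Swizzle<3,3,3>{},
                                    Layout<Shape <_16,Shape <_8,  _4>>,
                                           Stride< _8,Stride<_1,_128>>>{});
    print_layout(swizzle_atom);  // 2D table of offsets -- scan for bank conflicts
    print_latex(swizzle_atom);   // LaTeX source; compile with pdflatex to visualize
    return 0;
  }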

@capybara-club (Author)

Thank you very much. Seeing the numbers change, running the layouts from the two versions through pdflatex, and comparing them gives me a good amount of study material.

hwu36 closed this as completed Feb 6, 2025