CUTLASS Grouped GEMM #6
Conversation
NVIDIA/cutlass#1286 should be available on H100 now as well.
Hi! Thanks for the PR! We have users who currently rely on the cuBLAS path for Hopper, which I believe this PR deletes.
Since this is now available, it'd be great to support SM90! It looks like it requires a very new version of CUDA, though, so it may be best to keep the simple cuBLAS implementation around as a fallback for when we can't support CUTLASS grouped GEMM (see the sketch below).
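For illustration, here's a minimal sketch of such a fallback, assuming hypothetical helpers `CutlassGroupedGemm()` and `CublasGroupedGemm()` for the two backends; the CUDA-version cutoff and the compute-capability check are illustrative, not the repository's actual logic:

```cpp
#include <cuda_runtime.h>

// Hypothetical stand-ins for the two backends; the real helpers would
// take the operand tensors, batch sizes, and a stream.
void CutlassGroupedGemm() { /* CUTLASS grouped GEMM path */ }
void CublasGroupedGemm() { /* simple loop of cuBLAS GEMMs */ }

void GroupedGemm() {
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12000
  // Only attempt the CUTLASS path when built against a new-enough CUDA
  // toolkit and running on a sufficiently recent architecture.
  int device = 0, major = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  if (major >= 8) {  // Ampere or newer (illustrative cutoff).
    CutlassGroupedGemm();
    return;
  }
#endif
  // Otherwise fall back to the simple cuBLAS implementation.
  CublasGroupedGemm();
}
```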
Is there any critical reason why grouped GEMM is hardcoded to use BFloat16? Or would a string replace of bfloat16 with float16 just work?
There is no fundamental reason why we only support BFloat16; I implemented only bfloat16 because that is what the user who needed this feature uses. It would be relatively easy to template our helpers and dispatch based on the input tensor type, e.g., along the lines of the sketch below.
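A minimal sketch of that templating, assuming a hypothetical `CutlassGroupedGemmImpl<Element>` helper; the names and the dispatch shape are illustrative, not the repository's actual code:

```cpp
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

// Hypothetical templated backend. A real implementation would map Element
// to the corresponding CUTLASS element type (e.g., cutlass::bfloat16_t)
// when configuring the grouped GEMM kernel.
template <typename Element>
void CutlassGroupedGemmImpl(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  // ... configure and launch the grouped GEMM for Element ...
}

// Dispatch on the input tensor's dtype instead of hardcoding BFloat16.
void GroupedGemm(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  if (a.scalar_type() == torch::kBFloat16) {
    CutlassGroupedGemmImpl<__nv_bfloat16>(a, b, c);
  } else if (a.scalar_type() == torch::kFloat16) {
    CutlassGroupedGemmImpl<__half>(a, b, c);
  } else {
    TORCH_CHECK(false, "grouped GEMM: unsupported dtype ", a.scalar_type());
  }
}
```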
@tgale96 For context: I'm working on a branch that removes the CPU<->GPU sync for [...] (an illustration of that kind of sync is below). I also stumbled upon a nasty CUTLASS bug when one of the elements in [...].
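To illustrate the kind of sync being discussed (an assumption from context, not the branch's actual change): copying a device-resident sizes tensor to the host, e.g., to build the problem-size array on the CPU, blocks the host until the GPU catches up. Names below are hypothetical:

```cpp
#include <torch/torch.h>
#include <vector>

// Extract per-group sizes from a CUDA tensor. The .cpu() call forces a
// device-to-host copy, which synchronizes the host with the GPU; keeping
// this setup on the device is what avoids the sync.
std::vector<int64_t> ProblemSizesWithSync(const torch::Tensor& batch_sizes) {
  auto host_sizes = batch_sizes.cpu();  // Assumes an int64 1-D tensor.
  auto acc = host_sizes.accessor<int64_t, 1>();
  std::vector<int64_t> sizes(acc.size(0));
  for (int64_t i = 0; i < acc.size(0); ++i) sizes[i] = acc[i];
  return sizes;
}
```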
Hey! It would be great to have a full CUTLASS path, but I don't personally have the cycles for it at the moment. Contributions would be very welcome, and I'd be happy to provide any guidance that's necessary!
Cool! I opened #14 as a starting point. Any guidance would be much appreciated! :)
Use CUTLASS for grouped GEMM (covering the no-transposition, trans_a, and trans_b cases).
~20% speedup on A100.
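For reference, a loop of ordinary GEMMs spells out what the grouped GEMM computes; the operand layout here (`a` as groups concatenated along rows, `b` as a stacked 3-D tensor) is an assumption for illustration, not necessarily the PR's exact interface:

```cpp
#include <torch/torch.h>
#include <vector>

// Reference semantics: multiply each row-group of `a` by its own matrix
// from `b`. a: [sum(batch_sizes), k], b: [num_groups, k, n] (assumed
// layout). The CUTLASS path fuses this loop into one grouped kernel.
torch::Tensor GroupedGemmReference(const torch::Tensor& a,
                                   const torch::Tensor& b,
                                   const std::vector<int64_t>& batch_sizes) {
  std::vector<torch::Tensor> outs;
  int64_t start = 0;
  for (int64_t g = 0; g < (int64_t)batch_sizes.size(); ++g) {
    auto a_g = a.narrow(/*dim=*/0, start, batch_sizes[g]);
    outs.push_back(torch::matmul(a_g, b.select(/*dim=*/0, g)));
    start += batch_sizes[g];
  }
  return torch::cat(outs, /*dim=*/0);
}
```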