
CUTLASS Grouped GEMM #6

Open · imoneoi wants to merge 5 commits into main
Conversation

@imoneoi commented Dec 27, 2023

Use CUTLASS for grouped GEMM (covering the no-transposition, trans_a, and trans_b variants).

Yields a ~20% speedup on A100.
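For illustration, here is a rough sketch of what instantiating a CUTLASS grouped GEMM looks like, modeled on the CUTLASS 2.x `GemmGrouped` API (as in CUTLASS's `examples/24_gemm_grouped`). The tile shapes and types below are assumptions, not necessarily the configuration this PR uses:

```cpp
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/thread/linear_combination.h"

// One grouped-GEMM instantiation for SM80 (A100) with bfloat16 inputs and
// float accumulation. The trans_a / trans_b variants would swap the RowMajor
// layouts below for ColumnMajor.
using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass::bfloat16_t, cutlass::layout::RowMajor,       // A: element, layout
    cutlass::ComplexTransform::kNone, 8,                  // A: transform, alignment
    cutlass::bfloat16_t, cutlass::layout::RowMajor,       // B
    cutlass::ComplexTransform::kNone, 8,
    cutlass::bfloat16_t, cutlass::layout::RowMajor,       // C/D
    float,                                                // accumulator
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,  // tensor cores, A100
    cutlass::gemm::GemmShape<128, 128, 32>,               // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,                  // MMA instruction
    cutlass::epilogue::thread::LinearCombination<
        cutlass::bfloat16_t, 8, float, float>,            // D = alpha*AB + beta*C
    cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
    4                                                     // pipeline stages
>::GemmKernel;

using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
```

A single `GemmGrouped` launch then consumes an array of per-group problem sizes and per-group A/B/C/D pointers, instead of one kernel launch per group.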

@152334H commented Jan 1, 2024

NVIDIA/cutlass#1286 should be available on H100 now as well

@tgale96 (Owner) commented Jan 2, 2024

Hi! Thanks for the PR!

We have users who currently rely on the cuBLAS path for Hopper, which I believe this PR deletes.

> NVIDIA/cutlass#1286 should be available on H100 now as well

Since this is now available, it'd be great to add support for SM90! It looks like it requires a very new version of CUDA, so perhaps it would be best to keep the simple cuBLAS implementation around to fall back to if we can't support CUTLASS grouped GEMM.
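For what it's worth, a minimal sketch of the fallback dispatch being proposed (hypothetical helper name; the SM/toolkit thresholds are assumptions based on the discussion above):

```cpp
#include <cuda.h>
#include <cuda_runtime.h>

// Decide at runtime whether the CUTLASS grouped GEMM path can be used on the
// current device; otherwise the caller keeps the simple cuBLAS implementation.
bool CutlassGroupedGemmSupported(int device) {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) return false;
  int sm = prop.major * 10 + prop.minor;
#if CUDA_VERSION >= 12000
  // With a new enough toolkit, SM90 grouped GEMM (NVIDIA/cutlass#1286) is an option.
  return sm >= 80;
#else
  // Older toolkits: CUTLASS path for Ampere only, cuBLAS fallback for Hopper.
  return sm >= 80 && sm < 90;
#endif
}
```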

@152334H commented Jan 7, 2024

Is there any critical reason why grouped GEMM is hardcoded to use BFloat16? Or would a string replacement of bfloat16 with float16 just work?

@tgale96 (Owner) commented Jan 8, 2024

There is no reason why we only support BFloat16. I implemented only BFloat16 because that is what the user who needed this feature uses. It would be relatively easy to template our helpers and dispatch based on the input tensor type.
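As a sketch of that dispatch (hypothetical names; the actual helper signatures in this repo may differ):

```cpp
#include <torch/extension.h>
#include <cutlass/bfloat16.h>
#include <cutlass/half.h>

// Hypothetical CUTLASS-backed helper, templated on the element type and
// defined elsewhere.
template <typename Element>
void CutlassGroupedGemm(torch::Tensor a, torch::Tensor b, torch::Tensor c);

// Dispatch on the input tensor's dtype instead of hardcoding bfloat16.
void GroupedGemm(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  switch (a.scalar_type()) {
    case torch::kBFloat16:
      CutlassGroupedGemm<cutlass::bfloat16_t>(a, b, c);
      break;
    case torch::kFloat16:
      CutlassGroupedGemm<cutlass::half_t>(a, b, c);
      break;
    default:
      TORCH_CHECK(false, "grouped GEMM: unsupported dtype ", a.scalar_type());
  }
}
```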

@dfyz (Contributor) commented Jun 19, 2024

@tgale96
Hi! Are there any plans for this branch to eventually be merged? I'm not sure what the exact CUDA requirements for the newer CUTLASS versions are, but it might indeed be a good idea to restore the simple cuBLAS fallback for H100, so that this PR uses CUTLASS only on A100 (for now).

For context: I'm working on a branch that removes the CPU<->GPU sync for batch_sizes, and having both the forward and backward passes use CUTLASS is a prerequisite for that (since using cuBLAS effectively requires knowing the batch sizes on the host).
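To illustrate the host-size requirement, this is roughly what a cuBLAS-based grouped GEMM loop looks like (hypothetical names; a bf16, row-major forward pass is assumed; note the host-side loop over `batch_sizes`):

```cpp
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cstdint>
#include <vector>

// One cublasGemmEx call per group: m comes from batch_sizes, which therefore
// has to be known on the CPU (hence the CPU<->GPU sync). Row-major C = A * B
// is computed as column-major C^T = B^T * A^T by swapping the operands.
void CublasGroupedGemm(cublasHandle_t handle,
                       const std::vector<int64_t>& batch_sizes,  // host-side!
                       int64_t k, int64_t n,
                       const __nv_bfloat16* a,  // [sum(batch_sizes), k]
                       const __nv_bfloat16* b,  // [num_groups, k, n]
                       __nv_bfloat16* c) {      // [sum(batch_sizes), n]
  const float alpha = 1.0f, beta = 0.0f;
  for (int64_t m : batch_sizes) {
    cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                 static_cast<int>(n), static_cast<int>(m), static_cast<int>(k),
                 &alpha,
                 b, CUDA_R_16BF, static_cast<int>(n),
                 a, CUDA_R_16BF, static_cast<int>(k),
                 &beta,
                 c, CUDA_R_16BF, static_cast<int>(n),
                 CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
    a += m * k; b += k * n; c += m * n;  // advance to the next group
  }
}
```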

I also stumbled upon a nasty CUTLASS bug when one of the elements in batch_sizes is 0. I'm not sure when/if the fix is going to be merged upstream, but it might be a good idea to backport the relevant changes to the CUTLASS version used in this repo.
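Until the upstream fix lands, one possible host-side workaround is to drop empty groups when building the problem list, though this needs the sizes on the host and so works against the no-sync goal above (hypothetical helper; the real fix belongs in CUTLASS itself):

```cpp
#include <cstdint>
#include <vector>
#include <cutlass/gemm_coord.h>

// Build the per-group problem sizes for CUTLASS grouped GEMM, skipping groups
// with zero rows to sidestep the zero-sized-problem bug.
std::vector<cutlass::gemm::GemmCoord> BuildProblemSizes(
    const std::vector<int64_t>& batch_sizes, int64_t k, int64_t n) {
  std::vector<cutlass::gemm::GemmCoord> problems;
  for (int64_t m : batch_sizes) {
    if (m == 0) continue;  // empty group: no work, and it would trip the bug
    problems.emplace_back(static_cast<int>(m), static_cast<int>(n),
                          static_cast<int>(k));
  }
  return problems;
}
```

(The per-group A/B/C pointers would have to be filtered consistently.)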

@tgale96 (Owner) commented Jun 24, 2024

Hey! It would be great to have a full CUTLASS path, but I don't personally have the cycles for it at the moment. Contributions would be very welcome, and I'd be happy to provide any guidance that is necessary!

@dfyz (Contributor) commented Jun 24, 2024

Cool! I opened #14 as a starting point. Any guidance would be much appreciated! :)
