Scheduling a fast CUDA reduction #8336
Comments
Related to #8100

NVIDIA has some sample code where they demonstrate a "prefix sum" or "prefix scan" algorithm, IIRC. It sums the numbers of an array in I have not read this page in depth right now, but I think that getting all of the nice properties that the NVIDIA engineers have put into these algorithms is hard to achieve with Halide. Writing an optimized reduction in Halide for CPU is relatively trivial. Doing the same for GPU is impossible right now, I think, unless I'm overlooking tricks with
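For reference, the log-depth pattern those NVIDIA samples implement can be sketched in a few lines of NumPy. This is my own toy illustration of a Hillis-Steele inclusive scan, not the CUDA sample itself:

```python
import numpy as np

def inclusive_scan(a):
    """Log-depth (Hillis-Steele) inclusive prefix sum.

    Each of the ceil(log2(n)) passes adds in the value `stride`
    positions to the left; on a GPU every pass would be one
    data-parallel step across all lanes.
    """
    out = np.asarray(a, dtype=np.int64).copy()
    stride = 1
    while stride < len(out):
        shifted = np.zeros_like(out)
        shifted[stride:] = out[:-stride]  # lanes below `stride` add nothing
        out = out + shifted
        stride *= 2
    return out

a = np.arange(1, 9)
print(inclusive_scan(a))  # matches np.cumsum(a)
```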
For summations, logarithmic algorithms don't really matter. Once you sum down a small constant number of times, it doesn't matter how you schedule the rest - it's almost no work. For logarithmic algorithms we usually just use a fixed "sufficient" number of levels in the hierarchy, and disable or skip some of them based on the size of the input using things like RDom::where.
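The point about a fixed number of levels can be shown with a NumPy sketch (my own toy analogue, not actual Halide output): one parallel level of block-partial sums leaves only a tiny array to finish, and masking the padded lanes plays a role loosely analogous to restricting the reduction domain with RDom::where:

```python
import numpy as np

def two_level_sum(x, block=256):
    """Sum with one fixed level of partial sums, the way a GPU schedule
    would: each block produces one partial in parallel, then the tiny
    partial array is summed. Padding lanes are masked to zero instead
    of branching on them."""
    n = len(x)
    padded = np.zeros(-(-n // block) * block, dtype=np.int64)  # ceil-div padding
    padded[:n] = x                                  # masked: padding contributes 0
    partials = padded.reshape(-1, block).sum(axis=1)  # level 1: data-parallel
    return partials.sum()                           # level 2: almost no work left

x = np.arange(1000)
print(two_level_sum(x))  # 499500
```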
An example schedule for a reduction would be super useful. The best I was able to write was still 3.5x slower than the baseline. For a reduction (sum, prod, max, etc.), a constant number of levels should be fine (the Triton example above uses a constant number of levels, since it only uses a shared memory algorithm for the final fixed-size RBLOCK). Scan operations (cumsum, cumprod, etc.) may be harder, though they are relatively uncommon compared to reductions.
Here is my

Here is print_loop_nest before/after applying the schedule:

before:

after:
I am trying to get Halide to replicate the performance of Triton and am having some trouble. A lot of the issues seem related to the performance of reductions, so I wanted to start with a very simple reduction. Here is what I am compiling:
where x = torch.randn([8192, 8192], device="cuda"). We call this an inner reduction, since we are accumulating along the contiguous dimension, which is the most common type. (There are also outer reductions and mixed ones where we fuse multiple layouts together.)

On my local RTX 3090, Halide is much slower than Triton:
Our generated Halide code for this is:
The schedule generated by Anderson2021 is:
And Li2018 generates:
The (~4x faster) Triton code for the same function is:
Where XBLOCK=1 and RBLOCK=2048 (we use these for all inner reductions, but have different ones for other types).

This Triton code:
- Has an XBLOCK by RBLOCK accumulator block which iterates over the reduction dimension, summing up blocks of elements at a time in a data-parallel way. Triton automatically maps this 2D block to GPU threads, which may involve some unrolling.
- Calls tl.sum(), which does a shared memory parallel tree reduction. Something like:

Forming a tree until the first thread has the result. The details of this shared memory parallel reduction are figured out by Triton automatically, though the algorithm is pretty boilerplate.
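As a rough illustration (a NumPy simulation of the pattern, not Triton's actual generated code), the stride-halving tree that such a shared-memory block reduction performs looks like:

```python
import numpy as np

def block_tree_sum(vals):
    """Stride-halving tree reduction over one block, mirroring the kind
    of shared-memory reduction tl.sum() lowers to: at each level the
    first half of the active lanes adds in the second half, until lane 0
    holds the total. len(vals) is assumed to be a power of two."""
    vals = np.asarray(vals, dtype=np.float64).copy()
    active = len(vals)
    while active > 1:
        half = active // 2
        vals[:half] += vals[half:active]  # one parallel step per tree level
        active = half
    return vals[0]

v = np.arange(8, dtype=np.float64)
print(block_tree_sum(v))  # 28.0
```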
I was trying to write a Halide schedule for this, but got stuck pretty quickly. Some scheduling questions:

- Scheduling the sum.update(0) function the autoschedulers use, since it is defined inside hl.sum() (which returns an Expr, not a Func). How do I do that?

cc @abadams @alexreinking