Scheduling a fast CUDA reduction #8336

Open
jansel opened this issue Jul 3, 2024 · 6 comments

@jansel
Contributor

jansel commented Jul 3, 2024

I am trying to get Halide to match Triton's performance and am having some trouble. Many of the issues seem related to the performance of reductions, so I wanted to start with a very simple reduction. Here I am compiling:

torch.sum(x, -1)

where x = torch.randn([8192, 8192], device="cuda"). We call this an inner reduction, since we are accumulating over the contiguous dimension; it is the most common type. (There are also outer reductions, and mixed ones where we fuse multiple layouts together.)
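As a rough illustration of why the layout of the reduced dimension matters, this sketch computes, for a C-contiguous (row-major) tensor like the one torch.randn returns, how far apart in memory consecutive elements along the reduced dimension are. The helper name `reduction_stride` is hypothetical.

```python
def reduction_stride(shape, dim):
    """Element distance between consecutive elements along `dim` in a
    C-contiguous (row-major) tensor of the given shape."""
    dim %= len(shape)              # allow negative dims, as torch does
    stride = 1
    for size in shape[dim + 1:]:   # product of all sizes to the right of `dim`
        stride *= size
    return stride

shape = (8192, 8192)
print("inner reduction, torch.sum(x, -1):", reduction_stride(shape, -1))  # 1
print("outer reduction, torch.sum(x, 0):", reduction_stride(shape, 0))    # 8192
```

An inner reduction walks memory with stride 1, while an outer reduction over the same tensor jumps 8192 elements per step.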

On my local RTX 3090, Halide is much slower than Triton:

  • 0.3137 ms with the Triton backend
  • 1.3972 ms with Halide + Anderson2021 autoscheduler
  • 1.3562 ms with Halide + Li2018 autoscheduler
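One way to read these timings is as effective memory bandwidth. The sketch below assumes each kernel reads every float32 input element exactly once and writes one float32 per row, and it ignores caching; the names are illustrative.

```python
ROWS = COLS = 8192
BYTES_PER_FLOAT32 = 4
# Minimum traffic: read every input element once, write one float32 per row.
BYTES_MOVED = ROWS * COLS * BYTES_PER_FLOAT32 + ROWS * BYTES_PER_FLOAT32

TIMINGS_MS = {
    "Triton": 0.3137,
    "Halide + Anderson2021": 1.3972,
    "Halide + Li2018": 1.3562,
}

def effective_bandwidth_gbps(ms):
    """Bytes moved divided by runtime, in GB/s (1 GB = 1e9 bytes)."""
    return BYTES_MOVED / (ms * 1e-3) / 1e9

for name, ms in TIMINGS_MS.items():
    slowdown = ms / TIMINGS_MS["Triton"]
    print(f"{name}: {effective_bandwidth_gbps(ms):.0f} GB/s ({slowdown:.2f}x Triton's time)")
```

Under these assumptions Triton reaches roughly 856 GB/s, close to the RTX 3090's advertised ~936 GB/s, while both Halide schedules land just under 200 GB/s. That suggests the gap is mostly in how memory is accessed rather than in arithmetic.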

Our generated Halide code for this is:

import halide as hl

@hl.generator(name="kernel")
class Kernel:
    in_ptr0 = hl.InputBuffer(hl.Float(32), 2)
    ks0 = hl.InputScalar(hl.Int(32))
    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

    def generate(g):
        in_ptr0 = g.in_ptr0
        ks0 = g.ks0
        out_ptr0 = g.out_ptr0
        h0 = hl.Var('h0')
        h1 = hl.Var('h1')
        rdom = hl.RDom([hl.Range(0, ks0)])
        hr0 = rdom[0]
        tmp0 = hl.Func('tmp0')
        tmp0[h0, h1] = in_ptr0[h0, h1,]
        tmp1 = hl.Func('tmp1')
        tmp1[h1] = hl.sum(rdom, tmp0[hr0, h1])
        out_ptr0[h1,] = hl.cast(hl.Float(32), tmp1[h1])

        assert g.using_autoscheduler()
        in_ptr0.dim(0).set_min(0)
        in_ptr0.dim(0).set_stride(1)
        in_ptr0.dim(1).set_min(0)
        in_ptr0.set_estimates([hl.Range(0, 8192), hl.Range(0, 8192)])
        ks0.set_estimate(8192)
        out_ptr0.set_estimates([hl.Range(0, 8192)])
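For reference, here is a plain-Python model of what this generator computes. Halide's dimension 0 is the innermost one (the generator pins its stride to 1), so `in_ptr0[h0, h1]` corresponds to torch's `x[h1, h0]`. The parameters `row_stride` and `num_rows` are assumptions of this sketch: the generator does not pin dim 1's stride, and for the contiguous 8192×8192 input `row_stride` would be 8192.

```python
def reference_kernel(in_flat, ks0, row_stride, num_rows):
    """Model of the generator: out_ptr0[h1] = sum over h0 in [0, ks0) of in_ptr0[h0, h1].

    in_ptr0[h0, h1] lives at flat offset h0 + h1 * row_stride, because
    Halide's dimension 0 is innermost (stride 1).
    """
    return [
        sum(in_flat[h0 + h1 * row_stride] for h0 in range(ks0))
        for h1 in range(num_rows)
    ]

# A 3x4 example stored row-major, as torch would store it.
x = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12]]
flat = [v for row in x for v in row]
print(reference_kernel(flat, ks0=4, row_stride=4, num_rows=3))  # [10, 26, 42]
```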

The schedule generated by Anderson2021 is:

inline void apply_schedule_halide_kernel(
    ::Halide::Pipeline pipeline,
    ::Halide::Target target
) {
    using ::Halide::Func;
    using ::Halide::MemoryType;
    using ::Halide::RVar;
    using ::Halide::TailStrategy;
    using ::Halide::Var;
    Func out_ptr0 = pipeline.get_func(4);
    Func tmp1 = pipeline.get_func(3);
    Func sum = pipeline.get_func(2);
    Func tmp0 = pipeline.get_func(1);
    Var h1(out_ptr0.get_schedule().dims()[0].var);
    Var h1i("h1i");
    Var h1ii("h1ii");
    RVar r8_x(sum.update(0).get_schedule().dims()[0].var);
    Var h1i_serial_outer("h1i_serial_outer");
    out_ptr0
        .split(h1, h1, h1i, 8, TailStrategy::ShiftInwards)
        .split(h1i, h1i, h1ii, 2, TailStrategy::ShiftInwards)
        .unroll(h1ii)
        .compute_root()
        .reorder(h1ii, h1i, h1)
        .gpu_blocks(h1)
        .split(h1i, h1i_serial_outer, h1i, 4, TailStrategy::GuardWithIf)
        .gpu_threads(h1i);
    sum.update(0)
        .unroll(h1)
        .reorder(h1, r8_x);
    sum
        .unroll(h1)
        .compute_at(out_ptr0, h1i)
        .store_at(out_ptr0, h1)
        .reorder(h1);

}
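Reading this schedule as loop extents (assuming the 8192×8192 estimates) makes the launch shape explicit. The sketch only multiplies out the split factors; `split` and `anderson2021_geometry` are hypothetical helpers, not Halide APIs.

```python
import math

WARP_SIZE = 32  # threads per warp on NVIDIA GPUs

def split(extent, factor):
    """(outer, inner) extents of a Halide-style split; outer rounds up."""
    return math.ceil(extent / factor), factor

def anderson2021_geometry(rows=8192, cols=8192):
    h1_outer, h1i = split(rows, 8)            # .split(h1, h1, h1i, 8, ...)
    h1i, h1ii = split(h1i, 2)                 # .split(h1i, h1i, h1ii, 2, ...)
    serial_outer, threads = split(h1i, 4)     # .split(h1i, h1i_serial_outer, h1i, 4, ...)
    return {
        "blocks": h1_outer,                   # gpu_blocks(h1)
        "threads_per_block": threads,         # gpu_threads(h1i)
        "serial_outer_iters": serial_outer,
        "rows_per_thread": h1ii,              # unroll(h1ii)
        "serial_reduction_steps": cols,       # r8_x loop runs inside each thread
    }

geom = anderson2021_geometry()
print(geom)
print("fraction of a warp in use:", geom["threads_per_block"] / WARP_SIZE)
```

Under this reading each block has only 4 threads (an eighth of a warp), and each thread walks all 8192 reduction steps serially for 2 rows, so almost no parallelism is exposed along the reduction.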

And Li2018 generates:

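inline void apply_schedule_halide_kernel(
    ::Halide::Pipeline pipeline,
    ::Halide::Target target
) {
    using ::Halide::Func;
    using ::Halide::MemoryType;
    using ::Halide::RVar;
    using ::Halide::TailStrategy;
    using ::Halide::Var;
    // Func indices and variable lookups mirror the Anderson2021 schedule above.
    Func out_ptr0 = pipeline.get_func(4);
    Func sum = pipeline.get_func(2);
    Var h1(out_ptr0.get_schedule().dims()[0].var);
    RVar r8_x(sum.update(0).get_schedule().dims()[0].var);  // printed by Li2018 as r8$x
    Var v2("v2"), v3("v3"), v4("v4"), v5("v5"), v6("v6"), v7("v7");
    out_ptr0.compute_root()
        .split(h1, v2, v3, 64, TailStrategy::ShiftInwards)
        .reorder(v3, v2)
        .gpu_blocks(v2)
        .gpu_threads(v3);
    sum.compute_root()
        .split(h1, v4, v5, 64, TailStrategy::ShiftInwards)
        .reorder(v5, v4)
        .gpu_blocks(v4)
        .gpu_threads(v5);
    sum.update(0)
        .split(h1, v6, v7, 64, TailStrategy::GuardWithIf)
        .reorder(r8_x, v7, v6)
        .gpu_blocks(v6)
        .gpu_threads(v7);
}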
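A rough way to compare Li2018's thread mapping with Triton's is to count how many 128-byte memory segments one warp touches per load. This sketch assumes 32-lane warps, float32 elements, a row stride of 8192, and a simplified Triton lane layout (lane i reads element roffset + i; real Triton may give each thread several contiguous elements). All helper names are hypothetical.

```python
def segments_touched(element_offsets, elem_bytes=4, segment_bytes=128):
    """Distinct 128-byte memory segments hit by one warp-wide load."""
    return len({off * elem_bytes // segment_bytes for off in element_offsets})

def li2018_warp_offsets(step, block=0, warp_size=32, row_stride=8192):
    # sum.update(0): thread v7 of block v6 owns row h1 = v6 * 64 + v7 and, at
    # reduction step r8_x == step, reads in_ptr0[step, h1].
    return [(block * 64 + lane) * row_stride + step for lane in range(warp_size)]

def triton_warp_offsets(row, roffset, warp_size=32, row_stride=8192):
    # Simplified Triton model with XBLOCK=1: lane i reads element roffset + i of one row.
    return [row * row_stride + roffset + lane for lane in range(warp_size)]

print("Li2018 warp load:", segments_touched(li2018_warp_offsets(step=0)), "segments")
print("Triton warp load:", segments_touched(triton_warp_offsets(row=0, roffset=0)), "segments")
```

In the Li2018 mapping, adjacent threads sit one full row (32 KiB) apart, so each warp load is scattered across 32 segments; in the Triton mapping, one warp load stays within a single segment.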

The (~4x faster) Triton code for the same function is:

import triton
import triton.language as tl

@triton.jit
def triton_(in_ptr0, out_ptr0, ks0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (ks0*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(rmask & xmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp2, xmask)

Here XBLOCK=1 and RBLOCK=2048. We use these values for all inner reductions, and different ones for other reduction types.
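To make the structure of that kernel explicit, here is a sequential plain-Python emulation with XBLOCK=1. It assumes the grid launches exactly xnumel programs (so xmask is always true), and it runs lanes one after another where the GPU runs them in parallel. `triton_inner_sum_emulated` is a hypothetical name.

```python
def triton_inner_sum_emulated(x_flat, xnumel, rnumel, ks0, RBLOCK=2048):
    """Sequential emulation of the Triton kernel above with XBLOCK=1."""
    out = [0] * xnumel
    for pid in range(xnumel):                # tl.program_id(0): one program per row
        x0 = pid
        acc = [0] * RBLOCK                   # _tmp2: one running partial sum per lane
        for roffset in range(0, rnumel, RBLOCK):
            for lane in range(RBLOCK):       # tl.arange(0, RBLOCK); parallel on the GPU
                r1 = roffset + lane
                if r1 < rnumel:              # rmask: masked lanes keep their old value
                    acc[lane] += x_flat[r1 + ks0 * x0]
        out[x0] = sum(acc)                   # tl.sum(_tmp2, 1)
    return out

# 3 rows x 5 columns, with a deliberately small RBLOCK so the mask is exercised.
print(triton_inner_sum_emulated(list(range(15)), xnumel=3, rnumel=5, ks0=5, RBLOCK=4))
```

Each lane keeps its own partial sum across the whole roffset loop, so cross-lane communication happens only once per row, in tl.sum.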

This Triton code:

  1. Has an XBLOCK-by-RBLOCK accumulator block that iterates over the reduction dimension, summing blocks of elements at a time in a data-parallel way. Triton automatically maps this 2D block to GPU threads, which may involve some unrolling.

  2. Calls tl.sum(), which does a shared-memory parallel tree reduction. Something like:

    barrier 
    if (thread_id%K1)==0: accumulate from neighbors
    barrier
    if (thread_id%K2)==0: accumulate from neighbors
    ...

     This forms a tree until the first thread has the result. Triton works out the details of this shared-memory parallel reduction automatically, though the algorithm is fairly boilerplate.

  3. The first thread writes the final result out to main memory.
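The barrier-separated pseudo-code in step 2 can be modeled sequentially as below, with K1=2, K2=4, and so on. This is a simplified sketch that assumes a power-of-two block with one value per lane; Triton's actual generated code may differ (for example, by using warp shuffles within a warp).

```python
def shared_memory_tree_sum(lane_values):
    """Sequential model of an in-block tree reduction over a power-of-two block.
    Each pass of the while loop stands for the work between two barriers."""
    smem = list(lane_values)                 # one shared-memory slot per lane
    n = len(smem)
    assert n > 0 and n & (n - 1) == 0, "power-of-two block size assumed"
    stride, rounds = 1, 0
    while stride < n:
        # barrier
        for tid in range(0, n, 2 * stride):  # lanes with tid % (2 * stride) == 0
            smem[tid] += smem[tid + stride]  # accumulate from the neighbor `stride` away
        stride *= 2
        rounds += 1
    return smem[0], rounds                   # lane 0 holds the total

total, rounds = shared_memory_tree_sum(range(2048))
print(total, rounds)
```

Within a round, no lane reads a slot that another lane writes in the same round, which is why one barrier per round is enough. For RBLOCK=2048 the model takes log2(2048) = 11 rounds.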

I was trying to write a Halide schedule for this, but got stuck pretty quickly. Some scheduling questions:

  • Is there a way to express a shared-memory tree reduction in Halide? This is pretty important for performance here, since you don't want to go through main memory to sum things up.
  • I couldn't figure out how to do GPU tiling in the RDom() dimension. Is that possible?
  • Via the Python API, I couldn't figure out how to get a reference to the sum.update(0) stage that the autoschedulers use, since that Func is defined inside hl.sum() (which returns an Expr, not a Func). How do I do that?

cc @abadams @alexreinking

@abadams
Member

abadams commented Jul 3, 2024

  • Yes, a reduction tree is done with one rfactor per level in the hierarchy. The autoschedulers don't know how to do this. Li can apply a single rfactor, and Anderson doesn't know about rfactor at all. I'll try to write an example of a reasonably-scheduled summation.
  • No, because RVars can't be parallelized without atomic(), but atomic() is not the best tool for this. rfactor will convert them to Vars.
  • Use += instead of hl.sum if you want to schedule the update def. There's also a variant of sum that takes an as-yet undefined Func as the last arg and gives that func a pure and update definition that does the summation.

Related to #8100
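To make the first two points concrete: rfactor splits a reduction into an intermediate Func in which the chosen RVar has become a pure Var (and can therefore be parallelized), plus a merge stage that reduces the intermediate. The sketch below models one level for one row in plain Python. It is not Halide code, and `rfactor_sum` is a hypothetical name.

```python
def rfactor_sum(row, factor):
    """Model of split(r, ro, ri, factor) followed by rfactor(ri, u) for one row."""
    n = len(row)
    outer = (n + factor - 1) // factor
    intm = [0] * factor                   # intermediate Func, indexed by the new pure Var u
    for u in range(factor):               # pure Var: this loop may be parallelized (e.g. gpu_threads)
        for ro in range(outer):           # remaining RVar: serial reduction inside intm
            r = ro * factor + u
            if r < n:                     # guard for extents that factor doesn't divide
                intm[u] += row[r]
    return sum(intm)                      # merge stage: the original Func now reduces over intm

print(rfactor_sum(list(range(10)), 4))
```

Each additional level of the hierarchy applies the same transformation again to the merge stage.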

@mcourteaux
Contributor

NVIDIA has some sample code demonstrating a "prefix sum" (or "prefix scan") algorithm, IIRC. It sums the numbers of an array in O(log n) parallel steps. Something like what is outlined here: https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda

I have not read this page in depth yet, but I think getting all of the nice properties that the NVIDIA engineers put into these algorithms is hard to achieve with Halide. rfactor() is just not the tool for this, as you'd need a variable number of rfactors depending on your input size. It'd be cool if you could call out to a couple of optimized routines, such as this prefix scan, right from the Halide pipeline code.

Writing an optimized reduction in Halide for CPU is relatively trivial. Doing the same for GPU is impossible right now, I think, unless I'm overlooking tricks with RDoms that let you mimic this logarithmic algorithm?
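For reference, the linked chapter describes a work-efficient exclusive scan built from an up-sweep and a down-sweep over a balanced tree. Below is a minimal sequential sketch of that scheme. It assumes a power-of-two input length, and each inner `for` loop stands for work a GPU would run in parallel between barriers. Note that the up-sweep on its own is already a reduction tree: after it, the last element holds the total.

```python
def blelloch_exclusive_scan(values):
    """Work-efficient exclusive prefix sum (up-sweep, then down-sweep)."""
    a = list(values)
    n = len(a)
    assert n > 0 and n & (n - 1) == 0, "power-of-two length assumed"
    # Up-sweep (reduce): build partial sums up the tree; a[n - 1] ends up as the total.
    d = 1
    while d < n:
        for i in range(2 * d - 1, n, 2 * d):
            a[i] += a[i - d]
        d *= 2
    # Down-sweep: clear the root, then push prefixes back down the tree.
    a[n - 1] = 0
    d = n // 2
    while d >= 1:
        for i in range(2 * d - 1, n, 2 * d):
            left = a[i - d]
            a[i - d] = a[i]
            a[i] += left
        d //= 2
    return a

print(blelloch_exclusive_scan([3, 1, 7, 0, 4, 1, 6, 3]))  # [0, 3, 4, 11, 11, 15, 16, 22]
```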

@abadams
Member

abadams commented Jul 17, 2024

For summations, logarithmic algorithms don't really matter. Once you sum down a small constant number of times, it doesn't matter how you schedule the rest: it's almost no work.

For logarithmic algorithms we usually just use a fixed "sufficient" number of levels in the hierarchy, and disable or skip some of them based on the size of the input using things like RDom::where.
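A sequential sketch of that idea follows. It uses a fixed list of per-level factors, skips a level once its input is already a single value (standing in for an RDom::where-style condition), and guards ragged tails. The factors (256, 32, 32) are arbitrary illustrative choices and `fixed_level_sum` is a hypothetical name. The function also returns how many additions each level performs.

```python
def fixed_level_sum(row, factors=(256, 32, 32)):
    """Reduce with a fixed number of levels; each level sums groups of `factor` partials."""
    partials = list(row)
    work = []                               # additions performed at each level
    for factor in factors:
        if len(partials) <= 1:              # level disabled: input is already reduced
            work.append(0)
            continue
        groups = (len(partials) + factor - 1) // factor
        nxt = [0] * groups
        for g in range(groups):
            for k in range(factor):
                i = g * factor + k
                if i < len(partials):       # where-style guard for the ragged tail
                    nxt[g] += partials[i]
        work.append(len(partials))
        partials = nxt
    assert len(partials) == 1, "factors too small for this input size"
    return partials[0], work

print(fixed_level_sum(list(range(8192))))
```

For an 8192-element row, the first level does 8192 additions, the second only 32, and the third is skipped, so the later levels are almost no work.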

@jansel
Contributor Author

jansel commented Jul 17, 2024

An example schedule for a reduction would be super useful. The best I was able to write was still 3.5x slower than the baseline.

For a reduction (sum, prod, max, etc), a constant number of levels should be fine (the Triton example above uses a constant number of levels, since it only uses a shared memory algorithm for the final fixed-size RBLOCK).

Scan operations (cumsum, cumprod, etc.) may be harder, though they are relatively uncommon compared to reductions.

@iitaku

iitaku commented Aug 9, 2024

Here is my rfactor schedule for the parallel reduction on GPU. I tested it on the CUDA backend, and performance improved on my end. Hopefully it is useful for you.

import math

import halide as hl

@hl.generator(name="kernel")
class Kernel:
    in_ptr0 = hl.InputBuffer(hl.Float(32), 2)
    ks0 = hl.InputScalar(hl.Int(32))
    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

    def generate(g):
        in_ptr0 = g.in_ptr0
        ks0 = g.ks0
        out_ptr0 = g.out_ptr0
        h0 = hl.Var('h0')
        h1 = hl.Var('h1')
        rdom = hl.RDom([hl.Range(0, ks0)])
        hr0 = rdom[0]
        tmp0 = hl.Func('tmp0')
        tmp0[h0, h1] = in_ptr0[h0, h1,]
        tmp1 = hl.Func('tmp1')
        tmp1[h1] = 0.0
        tmp1[h1] += tmp0[hr0, h1]
        out_ptr0[h1,] = hl.cast(hl.Float(32), tmp1[h1])

        # Scheduling for parallel reduction 
        ro = hl.RVar('ro')
        ri = hl.RVar('ri')
        tidx = hl.Var('tidx')
        thread_num = 256
        
        intm = tmp1.update().split(rdom, ro, ri, thread_num).reorder(ro, ri, h1).rfactor(ri, tidx)
        intm.compute_at(tmp1, h1).gpu_threads(tidx).update(0).gpu_threads(tidx)

        ro, ri = ri, ro
        
        for i in range(int(math.log2(thread_num))):
            intm = tmp1.update().split(ro, ro, ri, 2).unroll(ri).rfactor(ro, tidx)
            intm.compute_at(tmp1, h1).gpu_threads(tidx).update(0).gpu_threads(tidx)
        
        out_ptr0.compute_root().gpu_blocks(h1)
        out_ptr0.print_loop_nest()
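The schedule above can be modeled in plain Python for a single output row (one GPU block). The model makes the access pattern visible. In the first level, thread tidx reads elements tidx, tidx + 256, tidx + 512, and so on, so at each ro step consecutive threads read consecutive addresses. Every later level halves the number of active threads by adding adjacent pairs. This is a sequential model under those assumptions, not the code Halide generates, and `rfactor_tree_sum` is a hypothetical name.

```python
def rfactor_tree_sum(row, thread_num=256):
    """Sequential model of the rfactor schedule above for one output row."""
    assert thread_num > 0 and thread_num & (thread_num - 1) == 0, "power of two assumed"
    n = len(row)
    # Level 0: split(rdom, ro, ri, thread_num) + rfactor(ri, tidx).
    intm = [0] * thread_num
    for tidx in range(thread_num):                            # gpu_threads(tidx)
        for ro in range((n + thread_num - 1) // thread_num):  # serial r8.ro loop
            r = ro * thread_num + tidx
            if r < n:
                intm[tidx] += row[r]
    # Remaining levels: split(ro, ro, ri, 2).unroll(ri) + rfactor(ro, tidx).
    levels = 0
    while len(intm) > 1:
        intm = [intm[2 * t] + intm[2 * t + 1]                 # unrolled ri in [0, 1]
                for t in range(len(intm) // 2)]               # tidx in [0, len/2 - 1]
        levels += 1
    return intm[0], levels                                    # final tmp1 update

print(rfactor_tree_sum(list(range(8192))))
```

With thread_num = 256, the model performs 8 pairwise levels, matching the tidx ranges [0, 127] down to [0, 0] in the loop nest iitaku posted next.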

@iitaku

iitaku commented Aug 9, 2024

Here is the print_loop_nest output before and after applying the schedule:

before:

produce out_ptr0:
  for h1:
    produce tmp1:
      tmp1(...) = ...
      for r8:
        tmp1(...) = ...
    consume tmp1:
      out_ptr0(...) = ...

after:

produce out_ptr0:
  gpu_block h1<Default_GPU>:
    produce tmp1:
      tmp1(...) = ...
      produce tmp1_intm:
        gpu_thread tidx in [0, 255]<Default_GPU>:
          tmp1_intm(...) = ...
        gpu_thread tidx in [0, 255]<Default_GPU>:
          for r8.ro:
            tmp1_intm(...) = ...
      consume tmp1_intm:
        produce tmp1_intm:
          gpu_thread tidx in [0, 127]<Default_GPU>:
            tmp1_intm(...) = ...
          gpu_thread tidx in [0, 127]<Default_GPU>:
            unrolled r8.ri.ro in [0, 1]:
              tmp1_intm(...) = ...
        consume tmp1_intm:
          produce tmp1_intm:
            gpu_thread tidx in [0, 63]<Default_GPU>:
              tmp1_intm(...) = ...
            gpu_thread tidx in [0, 63]<Default_GPU>:
              unrolled r8.ri.ri.ro in [0, 1]:
                tmp1_intm(...) = ...
          consume tmp1_intm:
            produce tmp1_intm:
              gpu_thread tidx in [0, 31]<Default_GPU>:
                tmp1_intm(...) = ...
              gpu_thread tidx in [0, 31]<Default_GPU>:
                unrolled r8.ri.ri.ri.ro in [0, 1]:
                  tmp1_intm(...) = ...
            consume tmp1_intm:
              produce tmp1_intm:
                gpu_thread tidx in [0, 15]<Default_GPU>:
                  tmp1_intm(...) = ...
                gpu_thread tidx in [0, 15]<Default_GPU>:
                  unrolled r8.ri.ri.ri.ri.ro in [0, 1]:
                    tmp1_intm(...) = ...
              consume tmp1_intm:
                produce tmp1_intm:
                  gpu_thread tidx in [0, 7]<Default_GPU>:
                    tmp1_intm(...) = ...
                  gpu_thread tidx in [0, 7]<Default_GPU>:
                    unrolled r8.ri.ri.ri.ri.ri.ro in [0, 1]:
                      tmp1_intm(...) = ...
                consume tmp1_intm:
                  produce tmp1_intm:
                    gpu_thread tidx in [0, 3]<Default_GPU>:
                      tmp1_intm(...) = ...
                    gpu_thread tidx in [0, 3]<Default_GPU>:
                      unrolled r8.ri.ri.ri.ri.ri.ri.ro in [0, 1]:
                        tmp1_intm(...) = ...
                  consume tmp1_intm:
                    produce tmp1_intm:
                      gpu_thread tidx in [0, 1]<Default_GPU>:
                        tmp1_intm(...) = ...
                      gpu_thread tidx in [0, 1]<Default_GPU>:
                        unrolled r8.ri.ri.ri.ri.ri.ri.ri.ro in [0, 1]:
                          tmp1_intm(...) = ...
                    consume tmp1_intm:
                      produce tmp1_intm:
                        gpu_thread tidx in [0, 0]<Default_GPU>:
                          tmp1_intm(...) = ...
                        gpu_thread tidx in [0, 0]<Default_GPU>:
                          unrolled r8.ri.ri.ri.ri.ri.ri.ri.ri.ro in [0, 1]:
                            tmp1_intm(...) = ...
                      consume tmp1_intm:
                        tmp1(...) = ...
    consume tmp1:
      out_ptr0(...) = ...

@alexreinking alexreinking self-assigned this Sep 13, 2024