Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizing Shared Memory Usage for a Tensor Product in Halide on GPU #8427

Closed
nullplay opened this issue Sep 27, 2024 · 0 comments
Closed

Optimizing Shared Memory Usage for a Tensor Product in Halide on GPU #8427

nullplay opened this issue Sep 27, 2024 · 0 comments

Comments

@nullplay
Copy link

Hi,

I'm implementing a tensor product operation in Halide that involves gathering inputs and scattering the final output on a GPU. I'm aiming to optimize shared memory usage for better performance, but I'm encountering some challenges.

Here's a reproduction of my Halide generator code:

import halide as hl

@hl.generator(name="test")
class Test:
    """Halide generator for a gathered tensor product with a scattered output.

    Conceptually (per the algorithm below):
        output[m, omap[p, n]] += sum_c input[c, imap[p, n]] * weight[m, c, n]
    i.e. rows of `input` are gathered via `imap`, contracted against `weight`
    over the channel dimension, and the results are scatter-accumulated into
    `output` at positions given by `omap` (with atomics on the GPU).
    """

    # Scalar runtime parameters: number of (p, n) positions to reduce over,
    # and an upper bound used to clamp scattered output indices.
    maxPos = hl.InputScalar(hl.Int(32))
    outSize = hl.InputScalar(hl.Int(32))

    # Index maps: imap selects input rows to gather, omap selects output
    # rows to scatter into. Both are indexed as [p, n].
    imap = hl.InputBuffer(hl.Int(32), 2)
    omap = hl.InputBuffer(hl.Int(32), 2)

    weight = hl.InputBuffer(hl.Float(32), 3)
    input = hl.InputBuffer(hl.Float(32), 2)
    output = hl.OutputBuffer(hl.Float(32), 2)

    def generate(g):
        # Local aliases for the generator's inputs/outputs.
        imap = g.imap
        omap = g.omap
        weight = g.weight
        input = g.input
        output = g.output
        maxPos = g.maxPos
        outSize = g.outSize

        # Pin the inner two weight dimensions to 64x64 so the compiler can
        # treat the contraction extents as compile-time constants.
        weight.dim(1).set_bounds(0,64)
        weight.dim(0).set_bounds(0,64)

        # Variable Definition
        o,n,c,p,m = hl.vars("o n c p m")
        m1, m0 = hl.vars("m1 m0")
        p1, p0 = hl.RVar("p1"), hl.RVar("p0")
        c1, c0 = hl.RVar("c1"), hl.RVar("c0")

        # Algorithm
        # 1. Gather Input: indirect load of input rows through imap,
        #    with the index promised to be within input's bounds.
        gather_input = hl.Func("GatherInput")
        gather_input[c, p, n] = input[c, hl.unsafe_promise_clamped(imap[p, n], 0, input.dim(0).max())]

        # 2. Load Weight (Identity): wrapper Func so the weight loads can be
        #    scheduled (e.g. staged into GPU shared memory) independently.
        gather_weight = hl.Func("GatherWeight")
        gather_weight[m, c, n] = weight[m, c, n]

        # 3. Tensor Product: contract over the channel dimension (r1).
        r1 = hl.RDom([(0, weight.dim(1).extent())])
        product = hl.Func("Product")
        product[m, p, n] = 0.0
        product[m, p, n] += gather_input[r1.x, p, n] * gather_weight[m, r1.x, n]

        # 4. Scatter Product to Output: accumulate each product value into
        #    the output row chosen by omap; r2 spans (n, p).
        r2 = hl.RDom([(0, weight.dim(2).extent()), (0, maxPos)])
        output[m,o] = 0.0
        output[m, hl.unsafe_promise_clamped(omap[r2.y, r2.x], 0, outSize)] += product[m, r2.y, r2.x]


        # Schedule
        # Pure (initialization) stage: one GPU block per output row o,
        # one thread per m.
        (output
            .reorder(m, o)
            .gpu_blocks(o).gpu_threads(m)
        )

        # Update (scatter) stage: tile the reduction into 128x32 blocks of
        # (p, m), then 4x1 sub-tiles per thread; .atomic() is required
        # because distinct (p, n) positions may scatter to the same output
        # location.
        (output.update(0)
            .tile(r2.y, m, p1, m1, 128, 32)
            .tile(p1, m1, p0, m0, 4, 1)
            .reorder(m0, p0, m1, p1, m, r2.y, r2.x)
            .atomic().gpu_blocks(m,r2.y).gpu_threads(m1,p1)
        )

        # Accumulate the per-thread 4x1 product tile in registers; split the
        # channel reduction so weight loads can be staged per c1 iteration.
        (product
            .compute_at(output, m1)
            .store_in(hl.MemoryType.Register)
            .update(0)
            .split(r1.x, c1, c0, 16)
            .reorder(m,c0,p,c1,n)
        )

        # Stage the weight slice needed by each c1 iteration into GPU shared
        # memory (the subject of this issue: the allocation comes out larger
        # than the m1 x c0 footprint the author expects).
        (gather_weight
            .compute_at(product, c1)
            .store_in(hl.MemoryType.GPUShared)
        )



# Instantiate the generator under a CUDA host target, then JIT-compile the
# pipeline into a Python callable.
cuda_target = hl.Target("host-cuda")
with hl.GeneratorContext(cuda_target):
    gen = Test()
f = gen.compile_to_callable()

Outcome :

produce output:
  gpu_block o<Default_GPU>:
    gpu_thread m<Default_GPU>:
      output(...) = ...
  for r39:
    gpu_block r39.r39<Default_GPU>:
      gpu_block m.m<Default_GPU>:
        gpu_thread r39.p1.p1 in [0, 31]<Default_GPU>:
          gpu_thread m.m1.m1 in [0, 31]<Default_GPU>:
            produce Product:
              for p:
                Product(...) = ...
              for r28.c1:
                produce GatherWeight:
                  for c:
                    GatherWeight(...) = ...
                consume GatherWeight:
                  for p:
                    for r28.c0 in [0, 15]:
                      Product(...) = ...
            consume Product:
              for r39.p1.p0 in [0, 3]:
                output(...) = ...


let t197 = (maxPos + 127)/128
  let t199 = maxPos/128
  let t198 = (output.extent.0 + 31)/32
  let t201 = output.min.1*output.stride.1
  let t200 = (input.min.1*input.stride.1) + input.min.0
  for (output.s1.r39$x, 0, weight.extent.2) {
   let t204 = ((output.s1.r39$x - omap.min.1)*omap.stride.1) - omap.min.0
   let t203 = ((output.s1.r39$x - imap.min.1)*imap.stride.1) - imap.min.0
   let t202 = (output.s1.r39$x*weight.stride.2) + output.min.0
   gpu_block<CUDA> (output.s1.r39$y.r39$y.block_id_y, 0, t197) {
    gpu_block<CUDA> (output.s1.m.m.block_id_x, 0, t198) {
     allocate GatherWeight.0[float32 * 16384] in GPUShared
     gpu_thread<CUDA> (.thread_id_y, 0, 32) {
      gpu_thread<CUDA> (.thread_id_x, 0, 32) {
       if (output.s1.r39$y.r39$y.block_id_y < t199) {
        allocate Product.0[float32 * 4] in Register
        produce Product {
         let Product.s0.p.loop_extent.s = (maxPos - (output.s1.r39$y.r39$y.block_id_y*128)) - (.thread_id_y*4)
         let t205 = min(Product.s0.p.loop_extent.s, 4)
         for (Product.s0.p.rebased, 0, t205) {
          Product.0[Product.s0.p.rebased] = 0.000000f
         }
         let t173.s = (output.s1.m.m.block_id_x*32) + t202
         let t208 = min(Product.s0.p.loop_extent.s, 4)
         let t209 = (((output.s1.r39$y.r39$y.block_id_y*32) + .thread_id_y)*4) + t203
         let t206 = .thread_id_x + t173.s
         let t207 = (.thread_id_y*32) + .thread_id_x
         for (Product.s1.r28$x.c1, 0, 4) {
          produce GatherWeight {
           let t210 = Product.s1.r28$x.c1*16
           for (GatherWeight.s0.c.rebased, 0, 16) {
            GatherWeight.0[(GatherWeight.s0.c.rebased*1024) + t207] = weight[((GatherWeight.s0.c.rebased + t210)*weight.stride.1) + t206]
           }
          }
          consume GatherWeight {
           let t178 = (Product.s1.r28$x.c1*16) - t200
           for (Product.s1.p.rebased, 0, t208) {
            let t179 = Product.s1.p.rebased + t209
            for (Product.s1.r28$x.c0, 0, 16) {
             Product.0[Product.s1.p.rebased] = Product.0[Product.s1.p.rebased] + (input[((imap[t179]*input.stride.1) + t178) + Product.s1.r28$x.c0]*GatherWeight.0[(Product.s1.r28$x.c0*1024) + t207])
            }
           }
          }
         }
        }
        consume Product {
         let t181.s = (output.s1.m.m.block_id_x*32) - t201
         let t212 = (((output.s1.r39$y.r39$y.block_id_y*32) + .thread_id_y)*4) + t204
         let t211 = .thread_id_x + t181.s
         for (output.s1.r39$y.p1.p0, 0, 4) {
          let t139 = (omap[output.s1.r39$y.p1.p0 + t212]*output.stride.1) + t211
          let t140 = Product.0[output.s1.r39$y.p1.p0]
          atomic (output) {
            output[t139] = output[t139] + t140
          }
         }
        }
        free Product.0

Objective:

I want to achieve the following optimizations on the GPU:

  1. Accumulate the product in a 4x1 (p0 x m0) register block.

    • This is successfully achieved using:
      product.compute_at(output, m1).store_in(hl.MemoryType.Register)
  2. Load gather_weight into shared memory at the outer reduction loop (c1) in product.

    • Inside c1, gather_weight requires m1(32) x c0(16) = 512 elements.
    • Per GPU block, there are m1(32) x p1(32) GPU threads. Since c0 is independent of the p1 dimension (the y-axis of GPU threads), we can reuse gather_weight across threads if we load it into shared memory.
    • Ideal Scenario:
      • Shared Memory Allocation: Allocate only m1(32) x c0(16) = 512 elements.
      • Data Loading: Use only a subset of GPU threads, such as m1(32) x (p1/2)(16), to load gather_weight into shared memory.

Issue Encountered:

  • Current Observation:

    • The actual shared memory allocation is significantly larger than expected (e.g., 16384 elements).
    • Only the x-axis of GPU threads is used for loading gather_weight into shared memory.
  • Hypothesis:

    • Halide might not recognize that gather_weight can be reused across independent thread variables (p1), leading to a larger shared memory allocation (m1 x c0 x p1).

Attempted Solution:

I tried adjusting the schedule to compute gather_weight at output instead of at product:

# Alternative schedule: hoist GatherWeight up to output's per-block level
# (instead of inside Product's c1 loop) and use a 32x16 subset of the
# block's threads (m0 x wc0) to cooperatively load it into shared memory.
(gather_weight
            .compute_at(output, m)
            .store_in(hl.MemoryType.GPUShared)
            .split(c, wc1, wc0, 16)
            .split(m, m1, m0, 32)
            .gpu_threads(m0, wc0)
)

loop nest and conceptual stmt :

produce output:
  gpu_block o<Default_GPU>:
    gpu_thread m<Default_GPU>:
      output(...) = ...
  for r39:
    gpu_block r39.r39<Default_GPU>:
      gpu_block m.m<Default_GPU>:
        produce GatherWeight:
          for c.wc1:
            gpu_thread c.wc0 in [0, 15]<Default_GPU>:
              gpu_thread m.m0 in [0, 31]<Default_GPU>:
                GatherWeight(...) = ...
        consume GatherWeight:
          gpu_thread r39.p1.p1 in [0, 31]<Default_GPU>:
            gpu_thread m.m1.m1 in [0, 31]<Default_GPU>:
              produce Product:
                for p:
                  Product(...) = ...
                for r28.c1:
                  for p:
                    for r28.c0 in [0, 15]:
                      Product(...) = ...
              consume Product:
                for r39.p1.p0 in [0, 3]:
                  output(...) = ...


let t160 = (maxPos + 127)/128
  let t161 = (output.extent.0 + 31)/32
  let t163 = output.min.1*output.stride.1
  let t162 = (input.min.1*input.stride.1) + input.min.0
  for (output.s1.r39$x, 0, weight.extent.2) {
   let t166 = ((output.s1.r39$x - omap.min.1)*omap.stride.1) - omap.min.0
   let t165 = ((output.s1.r39$x - imap.min.1)*imap.stride.1) - imap.min.0
   let t164 = (output.s1.r39$x*weight.stride.2) + output.min.0
   gpu_block<CUDA> (output.s1.r39$y.r39$y.block_id_y, 0, t160) {
    gpu_block<CUDA> (output.s1.m.m.block_id_x, 0, t161) {
     allocate GatherWeight.0[float32 * 2048] in GPUShared
     gpu_thread<CUDA> (.thread_id_y, 0, 32) {
      gpu_thread<CUDA> (.thread_id_x, 0, 32) {
       allocate Product.0[float32 * 4] in Register
       if (.thread_id_y < 16) {
        produce GatherWeight {
         let t143.s = (output.s1.m.m.block_id_x*32) + t164
         let t167 = .thread_id_x + t143.s
         for (GatherWeight.s0.c.wc1, 0, 4) {
          let t158 = (GatherWeight.s0.c.wc1*16) + .thread_id_y
          GatherWeight.0[(t158*32) + .thread_id_x] = weight[(t158*weight.stride.1) + t167]
         }
        }
       }
       gpu_thread_barrier(2)
       consume GatherWeight {
        produce Product {
         let Product.s0.p.loop_extent.s = (maxPos - (output.s1.r39$y.r39$y.block_id_y*128)) - (.thread_id_y*4)
         let t168 = min(Product.s0.p.loop_extent.s, 4)
         for (Product.s0.p.rebased, 0, t168) {
          Product.0[Product.s0.p.rebased] = 0.000000f
         }
         let t169 = min(Product.s0.p.loop_extent.s, 4)
         let t170 = (((output.s1.r39$y.r39$y.block_id_y*32) + .thread_id_y)*4) + t165
         for (Product.s1.r28$x.c1, 0, 4) {
          let t148 = (Product.s1.r28$x.c1*16) - t162
          let t171 = Product.s1.r28$x.c1*16
          for (Product.s1.p.rebased, 0, t169) {
           let t151 = Product.s1.p.rebased + t170
           for (Product.s1.r28$x.c0, 0, 16) {
            Product.0[Product.s1.p.rebased] = Product.0[Product.s1.p.rebased] + (input[((imap[t151]*input.stride.1) + t148) + Product.s1.r28$x.c0]*GatherWeight.0[((Product.s1.r28$x.c0 + t171)*32) + .thread_id_x])
           }
          }
         }
        }
        consume Product {
         let output.s1.r39$y.p1.p0.epilogue.s = maxPos - (((output.s1.r39$y.r39$y.block_id_y*32) + .thread_id_y)*4)
         let t154.s = (output.s1.m.m.block_id_x*32) - t163
         let t172 = max(min(output.s1.r39$y.p1.p0.epilogue.s, 4), 0)
         let t174 = (((output.s1.r39$y.r39$y.block_id_y*32) + .thread_id_y)*4) + t166
         let t173 = .thread_id_x + t154.s
         for (output.s1.r39$y.p1.p0, 0, t172) {
          let t111 = (omap[output.s1.r39$y.p1.p0 + t174]*output.stride.1) + t173
          let t112 = Product.0[output.s1.r39$y.p1.p0]
          atomic (output) {
            output[t111] = output[t111] + t112
          }
         }
        }
        free Product.0
       }
      }
     }
     free GatherWeight.0
    }
   }
  }
 }
}
  • Result:

    • The loop nest now appears closer to the desired structure.
    • Only half of thread_id_y is involved in loading data into shared memory.
  • Remaining Issue:

    • The shared memory still allocates an entire m0 x c0 x c1 block, which is larger than necessary (m0 x c0).
    • Ideally, I want the shared memory loading to happen within the loop over c1, loading only a block of size m0 x c0 = 512 elements.

Is there a way to adjust the Halide schedule to achieve this shared memory usage? I think .compute_at(product, c1) is necessary at some point, but I don't know how to bring these shared memory loads inside c1 while meeting my requirements. I feel I'm almost there — or is this type of loop nest something Halide wasn't designed for?

@nullplay nullplay closed this as completed Nov 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant