
[DISCUSS] Tensorization of Warp Level Primitives #371

Open

tqchen opened this issue Apr 23, 2021 · 4 comments

tqchen (Contributor) commented Apr 23, 2021

AMD's and Nvidia's MFMA (matrix multiplication) operators operate at the warp level. This creates some interesting challenges for tensorization, semantics checks and the tensorization infra. This is a discussion issue that tries to capture some of these questions.

Case Study Example: AMD's mfma_32x32x1_f32 instruction

In AMD's case, the GPU has warp size = 64, i.e. the operations are done collectively by 64 threads, where the inputs and outputs are distributed along the registers of each thread. To make the presentation simple, we will use the following notations:

  • wid: warp index
  • Warp-level view of the memory (where the last dimension is always warp_size, indexed by wid). For example, the following code
warp float warpMatrixA[1][warp_size];

will get lowered into the following code in the thread-level view.

float registerA[1];

AMD's mfma_32x32x1_f32 is a batched matmul instruction that performs two matrix outer products. To see what happens, the instruction is equivalent to the following.

Logical Semantics

This is a batch matrix multiplication that divides the warp data into two 32x32 groups and performs the matmul:

for b, i, j in grid(2, 32, 32):
  C[b, i, j] += A[b, i] * B[b, j]

# matrix form (an outer product per batch)
for b in grid(2):
  C[b, :, :] += outer(A[b, :], B[b, :])

In order to implement the above logical semantics, C[2, 32, 32], B[2, 32] and A[2, 32] are stored as special registers in warp memory, using the following rules (<=> means the memory-map relation, wid is the warp index):

for x, wid in grid(32, 64):
   warpC[x][wid] <=> C[x // 16, x % 16 // 4 * 8 + x % 4 + wid // 32 * 4, wid % 32]

for wid in grid(64):
   warpA[wid] <=> A[wid // 32, wid % 32]

for wid in grid(64):
   warpB[wid] <=> B[wid // 32, wid % 32]

Namely, the data of A, B and C need to be laid out in a special way in the warp-level memory, which in turn maps to the corresponding registers (by removing the wid component).
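
To make the rule above concrete, here is a small NumPy check (a sketch, not part of the original post) that the warpC mapping is a bijection between the 32x64 register slots and the 2x32x32 elements of C:

import numpy as np

# every (x, wid) register slot should map to a distinct C[b, i, j]
seen = np.zeros((2, 32, 32), dtype=bool)
for x in range(32):          # per-thread register index into warpC
    for wid in range(64):    # lane index within the warp
        b = x // 16
        i = x % 16 // 4 * 8 + x % 4 + wid // 32 * 4
        j = wid % 32
        assert not seen[b, i, j], "mapping is not one-to-one"
        seen[b, i, j] = True
assert seen.all()            # every element of C is covered exactly once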

The actual GPU code looks as follows (using a simple example to illustrate the intrinsics):

// perform 2 batches of 32x32 matmul
kernel mfma_kernel(float *globalA[2, 32], float* globalB[2, 32], float* globalC[2, 32, 32]) {
   // assume a single warp; wid is the warp (lane) index
   int wid = threadIdx.x;
   // only need to allocate one A register and one B register per thread
   float rA[1], rB[1];
   // special registers to store results, need 16*2 registers per thread to represent warpC
   special_result_float rC[32] = {0};

   rA[0] = globalA[wid / 32, wid % 32];
   rB[0] = globalB[wid / 32, wid % 32];
   // run the intrinsic
   __mfma_32x32x1_f32(rC, rA, rB);
   // store back, following the warpC layout rule
   for (int i = 0; i < 32; ++i) {
      globalC[i / 16, i % 16 / 4 * 8 + i % 4 + wid / 32 * 4, wid % 32] = rC[i];
   }
}

// launch a single block of 64 threads (one warp)
mfma_kernel<<<1, 64>>>(globalA, globalB, globalC);

The above kernel performs

for b, i, j in grid(2, 32, 32):
    globalC[b, i, j] = globalA[b, i] * globalB[b, j] 
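
For reference, the kernel above can be simulated lane-by-lane in NumPy (a sketch, not part of the original post); the "intrinsic" is modelled as the cross-lane batched outer product, and the store loop uses the same index rule as the kernel:

import numpy as np

rng = np.random.default_rng(0)
globalA = rng.standard_normal((2, 32))
globalB = rng.standard_normal((2, 32))
globalC = np.zeros((2, 32, 32))

# per-lane register loads: lane wid holds A[wid // 32, wid % 32] and B[wid // 32, wid % 32]
rA = np.array([globalA[wid // 32, wid % 32] for wid in range(64)])
rB = np.array([globalB[wid // 32, wid % 32] for wid in range(64)])

# model of __mfma_32x32x1_f32: slot (x, wid) of warpC accumulates A[b, i] * B[b, j]
rC = np.zeros((32, 64))
for x in range(32):
    for wid in range(64):
        b, j = x // 16, wid % 32
        i = x % 16 // 4 * 8 + x % 4 + wid // 32 * 4
        rC[x, wid] += rA[b * 32 + i] * rB[b * 32 + j]

# store back with the same layout rule as the kernel
for x in range(32):
    for wid in range(64):
        globalC[x // 16, x % 16 // 4 * 8 + x % 4 + wid // 32 * 4, wid % 32] = rC[x, wid]

np.testing.assert_allclose(globalC, np.einsum('bi,bj->bij', globalA, globalB))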

In order to perform the matrix multiplication (tensorization), we need to perform the following steps:

  • S1: Use batch matmul to implement matmul. This can be achieved by duplicating the load of B, i.e. making B[0, j] = B[1, j], which results in the semantics of a 64x32 matmul (elaborated in the next section).
  • S2: Copy the data into the warpA, warpB memory (from shared memory) under the specific layout specified above. Note that this needs to be done by each of the threads (the wid always needs to be bound to the warp index when we copy data into it).
  • S3: Perform the mfma intrinsic (by passing in warpA, warpB and warpC).
  • S4: Lower warpA, warpB, warpC into real registers, and their loads/stores into per-thread-level loads/stores (this should be easy and would require the wid to always be bound to the warp index).
  • S5: Get the result out of warpC into (possibly) global memory.

Using BatchMatMul Intrinsic to Implement Matmul

It is possible to use the batch matmul intrinsic above to implement a plain matmul (by replicating one side of the elements). The logic is as follows (defining BB, AA, CC as the inputs and outputs of the matmul):

# replicate BB on all batches 
for b, j in grid(2, 32):
    B[b, j] <=> BB[j] 

for b, i in grid(2, 32):
    A[b, i] <=> AA[b * 32 + i]

for b, i, j in grid(2, 32, 32):
    C[b, i, j] <=> CC[b * 32 + i, j]

Then we have the following relationship:

for b, i, j in grid(2, 32, 32):
  C[b, i, j] += A[b, i] * B[b, j]

maps to <=>

for b, i, j in grid(2, 32, 32):
  CC[b * 32 + i, j] += AA[b * 32 + i] * BB[j]

which is exactly a 64x32 matmul.
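
A quick NumPy check of this replication trick (a sketch, not part of the original post): with k = 1, the resulting 64x32 matmul is just the outer product of AA and the replicated BB.

import numpy as np

rng = np.random.default_rng(1)
AA = rng.standard_normal(64)        # AA[b * 32 + i] <=> A[b, i]
BB = rng.standard_normal(32)        # BB[j], replicated on both batches

A = AA.reshape(2, 32)               # A[b, i]
B = np.stack([BB, BB])              # B[b, j] = BB[j] for b = 0, 1
C = np.einsum('bi,bj->bij', A, B)   # semantics of the batched instruction

np.testing.assert_allclose(C.reshape(64, 32), np.outer(AA, BB))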

Challenges and Questions

We can identify the following challenges that arise when tensorizing warp-level primitives.

  • C0: The GPU programming model is still at the thread level, so we need to declare the special warp memory as registers on each thread.
  • C1: There is a non-trivial logical mapping between the semantics of the tensor intrinsic and how we lay out the memory. For example, we could directly fold the real layout into the description of the tensor intrinsic, but that would make the reasoning more complicated.
  • C2: There is a need to be able to cache-read the data into warpA, warpB in the layout specific to the tensor intrinsic, before executing it.

It would be useful to discuss possible ways to solve these challenges, for example:

  • Q1: What would a cleaner "warp-level" programming model look like? Should we declare a block with warp memory?
  • Q2: How do we allow the declaration/plugin of the logical-physical layout mapping described in the above case?
  • Q3: How do we make a generic lowering to the original GPU programming model, where warp memory is declared as per-thread registers and intrinsics are instructions that take these registers?
vinx13 (Collaborator) commented Apr 24, 2021

tqchen (Contributor, Author) commented May 4, 2021

def matmul():
    # the tensor intrinsic: an 8x8 matmul with k = 1
    for i0, j0, k0 in grid(8, 8, 1):
        CC[i0, j0] += AA[i0, k0] * BB[j0, k0]

def func():
    # the original workload
    for i, j, k in grid(128, 128, 128):
        C[i, j] += A[i, k] * B[j, k]


def func_step0():
    # split the loops to expose an 8x8x1 inner tile
    for i1, j1, k1 in grid(16, 16, 128):
        for i0, j0, k0 in grid(8, 8, 1):
            C[i1 * 8 + i0, j1 * 8 + j0] += A[i1 * 8 + i0, k1 + k0] * B[j1 * 8 + j0, k1 + k0]


def func_step1():
    # cache the tiles into AA, BB, tensorize the inner loops, write CC back to C
    for i1, j1, k1 in grid(16, 16, 128):
        for ia, ka in grid(8, 1):
            AA[ia, ka] = A[i1 * 8 + ia, k1 + ka]
        for ib, kb in grid(8, 1):
            BB[ib, kb] = B[j1 * 8 + ib, k1 + kb]

        with tensorized:
            for i0, j0, k0 in grid(8, 8, 1):
                CC[i0, j0] += AA[i0, k0] * BB[j0, k0]

        for ic, jc in grid(8, 8):
            C[i1 * 8 + ic, j1 * 8 + jc] = CC[ic, jc]

def func_step2():
    # introduce Awarp with the warp-specific layout (2 elements x 4 lanes)
    for i1, j1, k1 in grid(16, 16, 128):
        for ia, ka in grid(8, 1):
            Awarp[ia % 2, ia // 2] = A[i1 * 8 + ia, k1 + ka]

        for ib, kb in grid(8, 1):
            BB[ib, kb] = B[j1 * 8 + ib, k1 + kb]

        for ia, ka in grid(8, 1):
            AA[ia, ka] = Awarp[ia % 2, ia // 2]

        with tensorized:
            for i0, j0, k0 in grid(8, 8, 1):
                CC[i0, j0] += AA[i0, k0] * BB[j0, k0]

        for ic, jc in grid(8, 8):
            C[i1 * 8 + ic, j1 * 8 + jc] = CC[ic, jc]


def func_step3():
    # bind the lane dimension of Awarp to the warp index
    for i1, j1, k1 in grid(16, 16, 128):

        for i in grid(2):
            for wid in thread_binding("warpIndex", 4):
                Awarp[i, wid] = A[i1 * 8 + wid * 2 + i, k1]

        for ib, kb in grid(8, 1):
            BB[ib, kb] = B[j1 * 8 + ib, k1 + kb]

        for ia, ka in grid(8, 1):
            AA[ia, ka] = Awarp[ia % 2, ia // 2]

        with tensorized:
            for i0, j0, k0 in grid(8, 8, 1):
                CC[i0, j0] += AA[i0, k0] * BB[j0, k0]

        for ic, jc in grid(8, 8):
            C[i1 * 8 + ic, j1 * 8 + jc] = CC[ic, jc]
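
As a sanity check on the Awarp re-layout used in func_step2/func_step3 (a sketch, not part of the original comment, assuming the step-2 indexing is Awarp[ia % 2, ia // 2]): lane wid of the 4 bound threads ends up holding rows wid*2 and wid*2 + 1 of the 8x1 A tile, so both steps produce the same buffer.

import numpy as np

A_tile = np.arange(8.0)                  # the 8 rows A[i1*8 : i1*8 + 8, k1]

# func_step2: sequential re-layout into the 2 x 4 warp buffer
Awarp_step2 = np.empty((2, 4))
for ia in range(8):
    Awarp_step2[ia % 2, ia // 2] = A_tile[ia]

# func_step3: copy with wid bound to thread_binding("warpIndex", 4)
Awarp_step3 = np.empty((2, 4))
for wid in range(4):
    for i in range(2):
        Awarp_step3[i, wid] = A_tile[wid * 2 + i]

np.testing.assert_allclose(Awarp_step2, Awarp_step3)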

Hzfengsy (Member) commented May 8, 2021

I have thought about a new proposal for TensorCore. Would like to have some discussion :)

Main Idea: wmma load/store changes data layout.

Currently, we write a load/store intrin desc like the following code:

with tir.block([16, 16], "store") as [vi, vj]:
    AA[vi, vj] = A[vi, vj]

However, the true behavior of load/store is the following (assuming we have a 16*16 warp op):

with tir.block([16, 16], "store") as [vi, vj]:
    AA[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj]

Hardware behavior

The warp fragment memory is somewhat contiguous (at least at the CUDA level).

With the wmma API, we declare warp memory using wmma::fragment[N] with N fragments, which is similar to float16 data[N][16][16]. Note that the memory is compact, just like a packed_layout at the warp-memory level.

Cache_read/write with re-layout support

To support this memory layout transformation during scheduling, we need to introduce a new primitive.
The current cache_read copies memory and lets all consumers read data from the cached memory in the same layout. However, we can enhance it with index shuffling (somewhat like CUDA swizzle). Here is an example:

AA = s.cache_read(A, lambda i, j: (i // 16, j // 16, i % 16, j % 16))

And the generated IR is

with tir.block([n, m]) as [i, j]:
    AA[i // 16, j // 16, i % 16, j % 16] = A[i, j]
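
The re-layout produced by such a cache_read can be sketched in NumPy (not part of the original comment): A is packed into contiguous 16x16 fragments, which is the same as a reshape plus transpose.

import numpy as np

n, m = 64, 48
A = np.random.default_rng(2).standard_normal((n, m))

# AA[i // 16, j // 16, i % 16, j % 16] = A[i, j]
AA = np.empty((n // 16, m // 16, 16, 16))
for i in range(n):
    for j in range(m):
        AA[i // 16, j // 16, i % 16, j % 16] = A[i, j]

np.testing.assert_allclose(AA, A.reshape(n // 16, 16, m // 16, 16).transpose(0, 2, 1, 3))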

Benefits

  1. No need for an affine map. The only thing we need is a mapping from i, j -> i // 16, j // 16, i % 16, j % 16, and bijectivity is not required.
  2. Enables storage_align and swizzle.
  3. Native support for TensorCore; no need to consider the warp during tensorize.
  4. May also work on other accelerators.

vinx13 (Collaborator) commented May 13, 2021

I have elaborated the workflow a bit; in the schedule:

@tvm.script.tir
def intrin_desc(a: ty.handle, b: ty.handle, c: ty.handle):
  # desc is like a vanilla matmul, with a special buffer scope
  A = tir.match_buffer(a, shape=(16, 16), scope='warp.layoutA')
  B = tir.match_buffer(b, shape=(16, 16), scope='warp.layoutB')
  C = tir.match_buffer(c, shape=(16, 16), scope='warp.layoutC')
  with block('root', [16, 16, tir.reduce_axis(16)]) as [vi, vj, vk]:
    tir.bind(vi, 0)
    tir.bind(vj, 0)
    tir.bind(vk, 0)
    for i, j, k in tir.grid(16, 16, 16):
      with block('C',  [16, 16, tir.reduce_axis(16)]) as [vii, vji, vki]:
        tir.bind(vii, vi + i)
        tir.bind(vji, vj + j)
        tir.bind(vki, vk + k)
        C[vii, vji] += A[vii, vki] * B[vji, vki]

@tvm.script.tir
def intrin_impl(a: ty.handle, b: ty.handle, c: ty.handle):
  # calling warp level intrinsic
  A = tir.match_buffer(a, shape=(16, 16), scope='warp.layoutA')
  B = tir.match_buffer(b, shape=(16, 16), scope='warp.layoutB')
  C = tir.match_buffer(c, shape=(16, 16), scope='warp.layoutC')
  with block('C', [16, 16, tir.reduce_axis(16)]) as [vii, vji, vki]:
    tir.mma_16x16x16(A, B, C, A_frag_index, B_frag_index, C_frag_index) # fragment indices are computed based on elem_offset, such as A.elem_offset // 256

def schedule_fn(sch):
  # split i, j, k and reorder ...
  sch.reorder(i0, j0, k0, i1, j1, k1)
  AA = sch.cache_read(A, 0, 'warp.layoutA')
  BB = sch.cache_read(B, 0, 'warp.layoutB')
  CC = sch.cache_write(C, 0, 'warp.layoutC')
  sch.compute_at(CC, k0)
  sch.compute_at(AA, k0)
  sch.compute_at(BB, k0)
  sch.tensorize(CC, i1, tensor_intrin)

The special layout can be lowered during buffer flattening. The intrinsic mma_16x16x16 also needs to be lowered to use the physical layout; it will become thread-level instructions.
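
For illustration (a sketch, not part of the original comment), the fragment index hinted at by the A.elem_offset // 256 comment above could be derived as follows for compact 16x16 fragments; the helper name is hypothetical.

def fragment_index(elem_offset: int, frag_rows: int = 16, frag_cols: int = 16) -> int:
    # each 16x16 fragment occupies frag_rows * frag_cols contiguous elements
    # in the compact warp-scope buffer, so the offset selects the fragment
    return elem_offset // (frag_rows * frag_cols)

assert fragment_index(0) == 0      # first fragment in the buffer
assert fragment_index(256) == 1    # second fragment starts right after it
assert fragment_index(512) == 2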
