Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Tensorize when loop extent = 1 #381

Open
vinx13 opened this issue May 3, 2021 · 4 comments
Open

[BUG] Tensorize when loop extent = 1 #381

vinx13 opened this issue May 3, 2021 · 4 comments

Comments

@vinx13
Copy link
Collaborator

vinx13 commented May 3, 2021

Tensorize currently doesn't work when an axis of a buffer has extent = 1. See the example.

import tvm
from tvm import te, tir
from tvm.script import ty

# Computation description of the tensor intrinsic: it is passed as the first
# argument of tir.TensorIntrin in main(). It spells out, as explicit loops,
# the computation that the intrinsic stands for:
#     C[i, j] = C[i, j] + A[i, k] * B[j, k]
# over a 32 x 32 output tile with a reduction axis of extent 1.
@tvm.script.tir
def intrin_mma_desc(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # A and B are (32, 1) float32 buffers; their second axis has extent 1,
    # which is the case this issue is about. C is the (32, 32) output tile.
    A = tir.match_buffer(a, (32, 1), "float32",  scope="global", offset_factor=1)
    B = tir.match_buffer(b, (32, 1), "float32",  scope="global", offset_factor=1)
    C = tir.match_buffer(c, (32, 32), "float32", scope="global", offset_factor=1)
    # Outer block "root": two axes of extent 32 and one reduce axis of extent 1.
    with tir.block([32, 32, tir.reduce_axis(0, 1)], "root") as [vi, vj, vk]:
        # All three block iter vars are bound to 0.
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        tir.bind(vk, 0)
        # Declared regions: the full C tile is both read and written; A and B
        # are each read as a 32 x 1 slice.
        tir.reads([C[vi:vi+32, vj:vj+32], A[vi:vi+32,vk:vk+1], B[vj:vj+32,vk:vk+1]])
        tir.writes(C[vi:vi+32, vj:vj+32])
        # The k loop has extent 1; it corresponds to the k1 loop produced by
        # s.split(k, factor=1) in main().
        for i, j, k in tir.grid(32, 32, 1):
            with tir.block([32, 32, tir.reduce_axis(0, 1)], "B") as [vii, vjj, vkk]:
                tir.bind(vii, vi + i)
                tir.bind(vjj, vj + j)
                # vkk is bound to the outer vk, not to the loop variable k.
                tir.bind(vkk, vk)
                # Both A and B are indexed with the reduction var in their
                # second axis.
                C[vii, vjj] = C[vii, vjj] + A[vii,vkk] * B[vjj,vkk]


# Implementation side of the tensor intrinsic: it is passed as the second
# argument of tir.TensorIntrin in main(). It declares the same buffers, block
# signature, bindings and read/write regions as intrin_mma_desc, but the body
# is a single tir.tvm_mma_sync call instead of explicit loops.
@tvm.script.tir
def intrin_mma_impl(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Same buffer shapes as in intrin_mma_desc: A, B are (32, 1); C is (32, 32).
    A = tir.match_buffer(a, (32, 1), "float32",  scope="global", offset_factor=1)
    B = tir.match_buffer(b, (32, 1), "float32",  scope="global", offset_factor=1)
    C = tir.match_buffer(c, (32, 32), "float32", scope="global", offset_factor=1)
    with tir.block([32, 32, tir.reduce_axis(0, 1)], "root") as [vi, vj, vk]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        tir.bind(vk, 0)
        tir.reads([C[vi:vi+32, vj:vj+32], A[vi:vi+32, vk:vk+1], B[vj:vj+32,vk:vk+1]])
        tir.writes(C[vi:vi+32, vj:vj+32])
        # Each buffer is passed as (data pointer, elem_offset // divisor).
        # NOTE(review): the divisors presumably are the tile element counts
        # (1024 = 32 * 32 for C, 32 = 32 * 1 for A and B) - TODO confirm.
        # The issue text reports that B.elem_offset is lowered with index 0
        # instead of vko in its second axis after tensorize.
        tir.evaluate(tir.tvm_mma_sync(C.data, C.elem_offset // 1024, A.data, A.elem_offset // 32, B.data, B.elem_offset // 32, dtype='handle'))


# Workload to be scheduled: registered under the key 'main' in main().
# Computes C[vi, vj] = sum over vk of A[vi, vk] * B[vj, vk] on 128 x 128
# buffers. Note that B is indexed as [vj, vk], i.e. by its row.
@tvm.script.tir
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # No dtype or scope is given here, so match_buffer's defaults apply.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    # Single block "C" with two axes of extent 128 and a reduce axis of
    # extent 128; main() looks it up by this name.
    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
        # Initialization of the accumulator.
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


def main():
    """Schedule ``matmul`` and tensorize it with the 32x32x1 intrinsic.

    Steps: wrap ``matmul`` in a module, split the two spatial axes of block
    'C' by 32 and the reduction axis by 1, reorder so the three inner loops
    are innermost, tensorize at the inner i loop, then print the resulting
    script. The surrounding issue text reports that this does not work
    because the extent-1 inner reduction loop does not match the intrinsic
    description.
    """
    module = tvm.script.create_module({'main': matmul})
    sch = tir.Schedule(module)

    # Loop axes of the matmul block, in the order returned by get_axes.
    block_c = sch.get_block('C')
    axis_i, axis_j, axis_k = sch.get_axes(block_c)

    # Tile: 32 x 32 on the spatial axes, extent 1 on the reduction axis.
    i_outer, i_inner = sch.split(axis_i, factor=32)
    j_outer, j_inner = sch.split(axis_j, factor=32)
    k_outer, k_inner = sch.split(axis_k, factor=1)
    sch.reorder(i_outer, j_outer, k_outer, i_inner, j_inner, k_inner)

    # Replace the inner loop nest, starting at i_inner, with the intrinsic.
    mma_intrin = tir.TensorIntrin(intrin_mma_desc, intrin_mma_impl)
    sch.tensorize(i_inner, mma_intrin)

    scheduled_func = sch.mod['main']
    print(tvm.script.asscript(scheduled_func))


# Run the reproduction immediately when this script is executed.
main()

The above code doesn't work because of a mismatch between the loop and the tensor intrinsic description. The loop k1 is eliminated from the block iter var (this is because of this).

If I remove this part of code, we still need to fix the pattern matcher here https://github.com/Hzfengsy/tvm-tensorir/blob/main/src/tir/schedule/primitives/blockize_tensorize.cc#L73 because the original loop after blockize will be

block C(iter_var(vi, range(min=0, ext=128)), iter_var(vj, range(min=0, ext=128)), iter_var(vk, range(min=0, ext=128)){
  bind(vi, ((vio*32) + i0_inner))
  bind(vj, ((vjo*32) + i1_inner))
  bind(vk, vko)  # the inner loop var of extent 1 is still eliminated.
  reads([C[vi, vj], A[vi, vk], B[vj, vk]])
  writes([C[vi, vj]])
  C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vj, vk]))
}

and as a result B.elem_offset is lowered to get_elem_offset(B[vjo * 32, 0]) instead of get_elem_offset(B[vjo * 32, vko]) because the detected binding of vk is incorrect.

The design question here is whether we should eliminate loops of extent 1 during blockize and tensorize.

@junrushao
Copy link
Member

Is this issue addressed yet? If so let's close this issue

@vinx13
Copy link
Collaborator Author

vinx13 commented Oct 27, 2021

I'll double check this when upstreaming

@Hzfengsy
Copy link
Member

cc @spectrometerHBH

@Hzfengsy
Copy link
Member

I guess it is because of affine map

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants