Copy issue when tensor dim is 1 #35
@DD-DuDa Thanks for reporting this. Would you mind providing the entire script to help us reproduce it?
Yeah, I've edited the issue and provided the whole code.
This is likely due to a bug in liveness analysis. For example, consider the following simplified version of the program you provided:

```python
@T.prim_func
def main(
    Q: T.Buffer(shape_q, dtype),
    K: T.Buffer(shape_kv, dtype),
    V: T.Buffer(shape_kv, dtype),
    glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
    Output_partial: T.Buffer(part_shape, dtype),  # [batch, seqlen_q, heads, num_split, dim]
    Output: T.Buffer(shape_q, dtype),
):
    with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128 * 2) as (bx, by, bz):
        Q_shared = T.alloc_shared([block_M, dim], dtype)
        hid = by % heads
        bid = by // heads
        T.copy(Q[bid, 0, hid, :], Q_shared[0, :])
        for d in T.serial(dim):
            Q_shared[0, d] = Q[bid, 0, hid, d]
```

When you generate the kernel code, you get:

```cpp
extern "C" __global__ void __launch_bounds__(256) main_kernel(half_t* __restrict__ Q) {
  extern __shared__ __align__(1024) half_t Q_shared[];
  if (((int)threadIdx.x) < 16) {
    *(uint4*)(Q_shared + (((int)threadIdx.x) * 8)) = *(uint4*)(Q + ((((int)blockIdx.y) * 128) + (((int)threadIdx.x) * 8)));
  }
  for (int d = 1; d < 128; ++d) {
    Q_shared[d] = Q[((((int)blockIdx.y) * 128) + d)];
  }
}
```

In this generated code, the first copy block behaves as expected and aligns with the intended functionality.
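As a sanity check on that first block, here is a standalone sketch (illustration only, not part of the thread) replaying its index arithmetic: 16 active threads each store one uint4, i.e. 8 half values, which covers exactly one row of dim = 128 elements, matching `T.copy(Q[bid, 0, hid, :], Q_shared[0, :])`. The constants are taken from the kernel above.

```python
# Standalone sketch (illustration only): replay the index arithmetic of the
# vectorized copy above, with dim = 128 and one uint4 holding 8 fp16 values.
dim = 128
elems_per_uint4 = 8  # 16 bytes per uint4 / 2 bytes per half

covered = []
for tx in range(16):  # the guard `threadIdx.x < 16` keeps 16 threads active
    start = tx * elems_per_uint4
    covered.extend(range(start, start + elems_per_uint4))

# The 16 vectorized stores touch Q_shared[0:128] exactly once each, i.e. one
# full row of length `dim`.
assert sorted(covered) == list(range(dim))
print("vectorized copy covers", len(covered), "contiguous elements")
```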
But it's also important for us to discover where the bug is located.
One debugging trick is to insert a debug print into the lowering pipeline:

```python
device_mod = tir.transform.Filter(is_device_call)(mod)
device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
device_mod = tir.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod)
print(device_mod)

if target.kind.name == "cuda":
    # Debug comments to get the code
    # code = tvm._ffi.get_global_func("target.build.tl_debug_codegen")(device_mod, target)
    device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
```

For the first program:

```python
@I.ir_module
class Module:
    @T.prim_func
    def main_kernel(Q: T.handle("float16", "global")):
        Q_1 = T.decl_buffer((16777216,), "float16", data=Q)
        Q_shared = T.handle("float16", "shared.dyn")
        Q_shared_1 = T.decl_buffer((131072,), "float16", data=Q_shared, scope="shared.dyn")
        bx = T.launch_thread("blockIdx.x", 32)
        Q_shared = T.allocate([131072], "float16", "shared.dyn")
        by = T.launch_thread("blockIdx.y", 32)
        bz = T.launch_thread("blockIdx.z", 4)
        v = T.launch_thread("threadIdx.x", 256)
        v_1 = T.launch_thread("threadIdx.y", 1)
        v_2 = T.launch_thread("threadIdx.z", 1)
        if v < 16:
            Q_shared_1[v * 64:v * 64 + 72:9] = Q_1[by * 128 + v * 8:by * 128 + v * 8 + 8]
```

For the last program:

```python
@I.ir_module
class Module:
    @T.prim_func
    def main_kernel(Q: T.handle("float16", "global")):
        Q_1 = T.decl_buffer((16777216,), "float16", data=Q)
        Q_shared = T.handle("float16", "shared.dyn")
        Q_shared_1 = T.decl_buffer((16384,), "float16", data=Q_shared, scope="shared.dyn")
        bx = T.launch_thread("blockIdx.x", 32)
        Q_shared = T.allocate([16384], "float16", "shared.dyn")
        by = T.launch_thread("blockIdx.y", 32)
        bz = T.launch_thread("blockIdx.z", 4)
        v = T.launch_thread("threadIdx.x", 256)
        v_1 = T.launch_thread("threadIdx.y", 1)
        v_2 = T.launch_thread("threadIdx.z", 1)
        if v < 16:
            Q_shared_1[v * 8:v * 8 + 8] = Q_1[by * 128 + v * 8:by * 128 + v * 8 + 8]
        for d in range(128):
            Q_shared_1[d] = Q_1[by * 128 + d]
```
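To make the difference between the two stores concrete, here is a standalone sketch (illustration only, not from the thread) that expands the destination indices for each active thread v: the contiguous slice in the second listing writes one tight 128-element row, while the stride-9 slice in the first listing scatters the same values across the oversized 131072-element buffer.

```python
# Illustration only: expand the destination indices of the two stores above
# for each active thread v (the guard keeps v in [0, 16)).
correct, buggy = [], []
for v in range(16):
    correct.extend(range(v * 8, v * 8 + 8))      # Q_shared_1[v*8 : v*8+8]
    buggy.extend(range(v * 64, v * 64 + 72, 9))  # Q_shared_1[v*64 : v*64+72 : 9]

print(sorted(correct) == list(range(128)))  # True: one contiguous 128-element row
print(max(buggy), len(set(buggy)))          # 1023 128: same 128 values scattered
                                            # with stride 9 up to index 1023
```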
The problem lies behind the VectorizeLoop pass. We can add prints around it (only the relevant fragment of each printout is shown below):

```python
print("Before vectorize loop \n", mod)
mod = tir.transform.VectorizeLoop()(mod)
print("After vectorize loop \n", mod)
```

Before vectorize loop:

```python
@I.ir_module
class Module:
    @T.prim_func
    def main(Q: T.Buffer((1, 4096, 32, 128), "float16"), K: T.Buffer((1, 4096, 32, 128), "float16"), V: T.Buffer((1, 4096, 32, 128), "float16"), glse: T.Buffer((1, 32, 4, 4096), "float16"), Output_partial: T.Buffer((1, 4096, 32, 4, 128), "float16"), Output: T.Buffer((1, 4096, 32, 128), "float16")):
        if v < 16:
            i = T.int32()
            T.attr(i, "pragma_unroll_explicit", T.bool(False))
            for i in T.vectorized(8):
                Q_shared = T.allocate([16384], "float16", "shared.dyn")
                Q_shared_1 = T.Buffer((16384,), "float16", data=Q_shared, scope="shared.dyn")
                Q_1 = T.Buffer((16777216,), "float16", data=Q.data)
                Q_shared_1[v * 8 + i] = Q_1[by * 128 + v * 8 + i]
```

After vectorize loop:

```python
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    @T.prim_func
    def main(Q: T.Buffer((1, 4096, 32, 128), "float16"), K: T.Buffer((1, 4096, 32, 128), "float16"), V: T.Buffer((1, 4096, 32, 128), "float16"), glse: T.Buffer((1, 32, 4, 4096), "float16"), Output_partial: T.Buffer((1, 4096, 32, 4, 128), "float16"), Output: T.Buffer((1, 4096, 32, 128), "float16")):
        if v < 16:
            i = T.int32()
            T.attr(i, "pragma_unroll_explicit", T.bool(False))
            Q_shared = T.allocate([131072], "float16", "shared.dyn")
            Q_shared_1 = T.Buffer((131072,), "float16", data=Q_shared, scope="shared.dyn")
            Q_1 = T.Buffer((16777216,), "float16", data=Q.data)
            Q_shared_1[v * 64:v * 64 + 72:9] = Q_1[by * 128 + v * 8:by * 128 + v * 8 + 8]
```
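The numbers in the "after" module are consistent with the allocation being duplicated per vector lane, so that a scalar access at index idx in lane i is rewritten to idx * lanes + i (my reading of the output, not something stated in the thread): the 16384-element buffer becomes 16384 * 8 = 131072, and v * 8 + i becomes (v * 8 + i) * 8 + i = v * 64 + 9 * i, which for i = 0..7 is exactly the stride-9 slice seen above. A small sketch:

```python
# Illustration only (my reading of the rewrite, not confirmed by the thread):
# if the allocation inside the vectorized loop is given one copy per lane,
# a scalar access at `idx` in lane `i` maps to `idx * lanes + i`.
lanes = 8

def rewritten_indices(v):
    return [(v * 8 + i) * lanes + i for i in range(lanes)]

v = 3  # any active thread id < 16
assert rewritten_indices(v) == list(range(v * 64, v * 64 + 72, 9))
print(rewritten_indices(v))  # [192, 201, 210, ..., 255]: base v*64, stride 9
```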
Got it! I learned a lot from that. Thank you!
Closing, as this has been resolved :)
Assume we have a Q tensor with shape [bs, 1, head, dim], and we allocate a shared-memory buffer Q_shared with shape [block_M, dim].
How do we copy Q_shared[0, :] = Q[bid, 0, hid, :]?
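In NumPy terms, the intended copy is simply the following (a sketch with made-up example sizes and indices, just to pin down the semantics):

```python
# Illustration only: the intended semantics of the copy, expressed in NumPy.
import numpy as np

bs, head, dim, block_M = 1, 32, 128, 64    # example sizes (assumed)
Q = np.random.rand(bs, 1, head, dim).astype(np.float16)
Q_shared = np.zeros((block_M, dim), dtype=np.float16)

bid, hid = 0, 5                            # example block/head indices (assumed)
Q_shared[0, :] = Q[bid, 0, hid, :]         # copy one length-dim row into row 0

assert np.array_equal(Q_shared[0], Q[bid, 0, hid])
```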
I got an error: