
Copy issue when tensor dim is 1 #35

Closed
DD-DuDa opened this issue Jan 23, 2025 · 8 comments
Labels
bug Something isn't working

Comments

DD-DuDa commented Jan 23, 2025

Assume we have a Q tensor with shape [bs, 1, head, dim], and we allocate a shared-memory buffer Q_shared of shape [block_M, dim].

How do we copy Q_shared[0, :] = Q[bid, 0, hid, :]?

# type: ignore

import torch
import torch.nn.functional as F
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial

def flashdecoding(batch, heads, seqlen_q, seqlen_kv, dim, is_casual, num_split, tune=False):
    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    shape_q = [batch, seqlen_q, heads, dim]
    shape_kv = [batch, seqlen_kv, heads, dim]
    part_shape = [batch, seqlen_q, heads, num_split, dim]
    dtype = "float16"
    accum_dtype = "float"

    def kernel_func(block_M, block_N):
        
        @T.macro
        def flash_attn_split(
            Q: T.Buffer(shape_q, dtype),
            K: T.Buffer(shape_kv, dtype),
            V: T.Buffer(shape_kv, dtype),
            glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
            Output_partial: T.Buffer(part_shape, dtype),
        ):
            print("flash_attn_split")
            with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128 * 2) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                O_shared = T.alloc_shared([block_M, dim], dtype)
                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_M], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
                scores_scale = T.alloc_fragment([block_M], accum_dtype)
                scores_sum = T.alloc_fragment([block_M], accum_dtype)
                logsum = T.alloc_fragment([block_M], accum_dtype)

                mid = bx
                hid = by % heads
                bid = by // heads
                sid = bz

                # T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})
                T.copy(Q[bid, 0, hid, :], Q_shared[0, :])
                # T.fill(acc_o, 0)
                # T.fill(logsum, 0)
                # T.fill(scores_max, -T.infinity(accum_dtype))

                # loop_range = (
                #     T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv((mid + 1) * block_M, block_N)) 
                #     if is_casual else T.ceildiv((seqlen_kv // num_split), block_N)
                # )

                # for k in T.Pipelined(loop_range, num_stages=2):
                #     MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid)
        
        @T.prim_func
        def main(
                Q: T.Buffer(shape_q, dtype),
                K: T.Buffer(shape_kv, dtype),
                V: T.Buffer(shape_kv, dtype),
                glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
                Output_partial: T.Buffer(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim]
                Output: T.Buffer(shape_q, dtype),
        ):
            print("hello")
            flash_attn_split(Q, K, V, glse, Output_partial)

        return main

    def kernel(block_M, block_N):
        return kernel_func(block_M, block_N)

    return kernel

def ref_program(Q, K, V, casual):
    assert casual is False
    dim = Q.size(-1)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=32, help='heads')
    parser.add_argument('--seqlen_kv', type=int, default=4096, help='sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--is_casual', action='store_true', help='causal')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()

    batch, heads, seqlen_kv, dim, is_casual = args.batch, args.heads, args.seqlen_kv, args.dim, args.is_casual
    seqlen_q   = 1
    num_splits = 4

    program = flashdecoding(
                batch, heads, seqlen_q, seqlen_kv, dim, is_casual, num_splits, tune=args.tune)(
                block_M=128, block_N=128)
    jit_kernel = tilelang.JITKernel(program, out_idx=[5], target="cuda")

    q = torch.randn(batch, seqlen_q, heads, dim, dtype=torch.float16, device='cuda')
    k = torch.randn(batch, seqlen_kv, heads, dim, dtype=torch.float16, device='cuda')
    v = torch.randn(batch, seqlen_kv, heads, dim, dtype=torch.float16, device='cuda')
    glse = torch.empty(batch, heads, num_splits, seqlen_q, dtype=torch.float16, device='cuda')
    output_partial = torch.empty(batch, seqlen_q, heads, num_splits, dim, dtype=torch.float16, device='cuda')

    out_ref = ref_program(q, k, v, is_casual)
    out_flash = jit_kernel(q, k, v, glse, output_partial)

    print(f"out_ref vs out_flash: {(out_ref - out_flash).abs().mean().item()}")

I got this error:

Traceback (most recent call last):
File "/home/shijiecao/Projects/BitAttn/tilelang/mha_kvcache.py", line 187, in
jit_kernel = tilelang.JITKernel(program, out_idx=[5], target="cuda")
File "/home/shijiecao/miniconda3/envs/bit/lib/python3.10/site-packages/tilelang/jit/kernel.py", line 75, in init
adapter = self.compile_and_create_adapter(func)
File "/home/shijiecao/miniconda3/envs/bit/lib/python3.10/site-packages/tilelang/jit/kernel.py", line 120, in compile_and_create_adapter
rt_mod, params = tilelang.lower(tilelang_func, target=target)
File "/home/shijiecao/miniconda3/envs/bit/lib/python3.10/site-packages/tilelang/engine/lower.py", line 223, in lower
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
File "/home/shijiecao/miniconda3/envs/bit/lib/python3.10/site-packages/tilelang/3rdparty/tvm/python/tvm/ffi/ctypes/packed_func.py", line 239, in call
raise_last_ffi_error()
File "/home/shijiecao/miniconda3/envs/bit/lib/python3.10/site-packages/tilelang/3rdparty/tvm/python/tvm/ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
ValueError: Traceback (most recent call last):
31: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
30: tvm::codegen::BuildTileLangCUDA(tvm::IRModule, tvm::Target)
29: tvm::codegen::CodeGenTileLangCUDA::AddFunction(tvm::tir::PrimFunc const&)
28: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
27: non-virtual thunk to tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
26: non-virtual thunk to tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
25: tvm::codegen::CodeGenTileLangCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
24: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
23: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
22: tvm::codegen::CodeGenTileLangCUDA::VisitStmt_(tvm::tir::AllocateNode const*)
21: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
20: tvm::codegen::CodeGenTileLangCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
19: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
18: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
17: tvm::codegen::CodeGenTileLangCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
16: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
15: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
14: tvm::codegen::CodeGenTileLangCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
13: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
12: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
11: tvm::codegen::CodeGenTileLangCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
10: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
9: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
8: tvm::codegen::CodeGenTileLangCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
7: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
6: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
5: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::IfThenElseNode const*)
4: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
3: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::BufferStoreNode const*)
2: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
1: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
0: tvm::codegen::CodeGenTileLangCUDA::VisitExpr_(tvm::tir::RampNode const*, std::ostream&)
File "/root/TileLang/src/target/codegen_cuda.cc", line 1257
ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed.

LeiWang1999 (Contributor) commented:

@DD-DuDa Thanks for reporting. Would you mind providing the entire script to help us reproduce?

DD-DuDa (Author) commented Jan 23, 2025:

Yeah, I've edited the issue and provided the whole code.

LeiWang1999 (Contributor) commented:

This is likely due to a bug in liveness analysis. For example, consider the following simplified version of the program you provided:

@T.prim_func
def main(
        Q: T.Buffer(shape_q, dtype),
        K: T.Buffer(shape_kv, dtype),
        V: T.Buffer(shape_kv, dtype),
        glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
        Output_partial: T.Buffer(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim]
        Output: T.Buffer(shape_q, dtype),
):
    with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128 * 2) as (bx, by, bz):
        Q_shared = T.alloc_shared([block_M, dim], dtype)

        hid = by % heads
        bid = by // heads

        T.copy(Q[bid, 0, hid, :], Q_shared[0, :])
        for d in T.serial(dim):
            Q_shared[0, d] = Q[bid, 0, hid, d]

When you generate the kernel code using print(jit_kernel.get_kernel_source()), the output is as follows:

extern "C" __global__ void __launch_bounds__(256) main_kernel(half_t* __restrict__ Q) {
  extern __shared__ __align__(1024) half_t Q_shared[];
  if (((int)threadIdx.x) < 16) {
    *(uint4*)(Q_shared + (((int)threadIdx.x) * 8)) = *(uint4*)(Q + ((((int)blockIdx.y) * 128) + (((int)threadIdx.x) * 8)));
  }
  for (int d = 1; d < 128; ++d) {
    Q_shared[d] = Q[((((int)blockIdx.y) * 128) + d)];
  }
}

In this generated code, the first copy block (generated from T.copy) behaves as expected and matches the intended functionality. So it's possible that the issue will resolve itself if you uncomment certain parts of the program.
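
For reference, here is a minimal sketch of how to dump the generated source yourself, reusing the flashdecoding builder from the script above and the JITKernel / get_kernel_source API already shown in this thread (the argument values are just the script's defaults):

import tilelang

# Build the program with the script's default sizes:
# batch=1, heads=32, seqlen_q=1, seqlen_kv=4096, dim=128, num_split=4.
program = flashdecoding(1, 32, 1, 4096, 128, False, 4)(block_M=128, block_N=128)

# Compile for CUDA and print the lowered kernel source to compare the
# T.copy path against the explicit per-element loop.
jit_kernel = tilelang.JITKernel(program, out_idx=[5], target="cuda")
print(jit_kernel.get_kernel_source())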

LeiWang1999 (Contributor) commented:

But it's also important for us to pinpoint where the bug is located.

LeiWang1999 (Contributor) commented:

One debugging trick is to insert a debug print in tilelang/engine/lower.py to see the lowered IR module:

device_mod = tir.transform.Filter(is_device_call)(mod)
device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
device_mod = tir.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod)
print(device_mod)
if target.kind.name == "cuda":
    # Debug comments to get the code
    # code = tvm._ffi.get_global_func("target.build.tl_debug_codegen")(device_mod, target)
    device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)

For the first program:

@I.ir_module
class Module:
    @T.prim_func
    def main_kernel(Q: T.handle("float16", "global")):
        Q_1 = T.decl_buffer((16777216,), "float16", data=Q)
        Q_shared = T.handle("float16", "shared.dyn")
        Q_shared_1 = T.decl_buffer((131072,), "float16", data=Q_shared, scope="shared.dyn")
        bx = T.launch_thread("blockIdx.x", 32)
        Q_shared = T.allocate([131072], "float16", "shared.dyn")
        by = T.launch_thread("blockIdx.y", 32)
        bz = T.launch_thread("blockIdx.z", 4)
        v = T.launch_thread("threadIdx.x", 256)
        v_1 = T.launch_thread("threadIdx.y", 1)
        v_2 = T.launch_thread("threadIdx.z", 1)
        if v < 16:
            Q_shared_1[v * 64:v * 64 + 72:9] = Q_1[by * 128 + v * 8:by * 128 + v * 8 + 8]

For the last program:

@I.ir_module
class Module:
    @T.prim_func
    def main_kernel(Q: T.handle("float16", "global")):
        Q_1 = T.decl_buffer((16777216,), "float16", data=Q)
        Q_shared = T.handle("float16", "shared.dyn")
        Q_shared_1 = T.decl_buffer((16384,), "float16", data=Q_shared, scope="shared.dyn")
        bx = T.launch_thread("blockIdx.x", 32)
        Q_shared = T.allocate([16384], "float16", "shared.dyn")
        by = T.launch_thread("blockIdx.y", 32)
        bz = T.launch_thread("blockIdx.z", 4)
        v = T.launch_thread("threadIdx.x", 256)
        v_1 = T.launch_thread("threadIdx.y", 1)
        v_2 = T.launch_thread("threadIdx.z", 1)
        if v < 16:
            Q_shared_1[v * 8:v * 8 + 8] = Q_1[by * 128 + v * 8:by * 128 + v * 8 + 8]
        for d in range(128):
            Q_shared_1[d] = Q_1[by * 128 + d]

LeiWang1999 (Contributor) commented:

The problem lies behind tir.transform.VectorizeLoop: after vectorization, the buffer store becomes the 8-lane strided ramp Q_shared_1[v * 64:v * 64 + 72:9] shown below, which trips the "Ramp of more than 4 lanes is not allowed" check in the CUDA codegen.

print("Before vectorize loop \n", mod)
mod = tir.transform.VectorizeLoop()(mod)
print("After vectorize loop \n", mod)
Before vectorize loop 

@I.ir_module
class Module:
    @T.prim_func
    def main(Q: T.Buffer((1, 4096, 32, 128), "float16"), K: T.Buffer((1, 4096, 32, 128), "float16"), V: T.Buffer((1, 4096, 32, 128), "float16"), glse: T.Buffer((1, 32, 4, 4096), "float16"), Output_partial: T.Buffer((1, 4096, 32, 4, 128), "float16"), Output: T.Buffer((1, 4096, 32, 128), "float16")):
        if v < 16:
            i = T.int32()
            T.attr(i, "pragma_unroll_explicit", T.bool(False))
            for i in T.vectorized(8):
                Q_shared = T.allocate([16384], "float16", "shared.dyn")
                Q_shared_1 = T.Buffer((16384,), "float16", data=Q_shared, scope="shared.dyn")
                Q_1 = T.Buffer((16777216,), "float16", data=Q.data)
                Q_shared_1[v * 8 + i] = Q_1[by * 128 + v * 8 + i]

After vectorize loop 
 # from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(Q: T.Buffer((1, 4096, 32, 128), "float16"), K: T.Buffer((1, 4096, 32, 128), "float16"), V: T.Buffer((1, 4096, 32, 128), "float16"), glse: T.Buffer((1, 32, 4, 4096), "float16"), Output_partial: T.Buffer((1, 4096, 32, 4, 128), "float16"), Output: T.Buffer((1, 4096, 32, 128), "float16")):
        if v < 16:
            i = T.int32()
            T.attr(i, "pragma_unroll_explicit", T.bool(False))
            Q_shared = T.allocate([131072], "float16", "shared.dyn")
            Q_shared_1 = T.Buffer((131072,), "float16", data=Q_shared, scope="shared.dyn")
            Q_1 = T.Buffer((16777216,), "float16", data=Q.data)
            Q_shared_1[v * 64:v * 64 + 72:9] = Q_1[by * 128 + v * 8:by * 128 + v * 8 + 8]
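
A possible interim workaround (just a sketch based on the simplified program above, not an official fix): copy the single query row with an explicit per-element loop instead of T.copy, so the store stays scalar and never reaches the strided-ramp vectorization path.

# Inside the T.Kernel body, replacing:
#     T.copy(Q[bid, 0, hid, :], Q_shared[0, :])
# with an element-wise copy of the single query row:
for d in T.serial(dim):
    Q_shared[0, d] = Q[bid, 0, hid, d]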

DD-DuDa (Author) commented Jan 23, 2025:

Got it! I learned a lot from that. Thank you!

LeiWang1999 added the bug label on Jan 24, 2025.
LeiWang1999 (Contributor) commented:

Closed, as this has been resolved :)
