Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Doc] Fix installation scripts and docs for dequantize gemm #20

Merged
merged 3 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
- 01/20/2025 ✨: We are excited to announce that tile-lang, a dsl for high performance AI workloads, is now open source and available to the public!

## Tested Devices
Although tile-lang aims to be portable across a range of Devices, it has been specifically tested and validated on the following devices: for NVIDIA GPUs, this includes the H100 (with Auto TMA/WGMMA support), A100, V100, RTX 4090, RTX 3090, and RTX A6000 (Ada); for AMD GPUs, it includes the MI250 (with Auto MatrixCore support) and the MI300X (with Async Copy support).
Although tile-lang aims to be portable across a range of Devices, it has been specifically tested and validated on the following devices: for NVIDIA GPUs, this includes the H100 (with Auto TMA/WGMMA support), A100, V100, RTX 4090, RTX 3090, and RTX A6000; for AMD GPUs, it includes the MI250 (with Auto MatrixCore support) and the MI300X (with Async Copy support).

## OP Implementation Examples
**tile-lang** provides the building blocks to implement a wide variety of operators. Some examples include:
Expand All @@ -24,7 +24,8 @@ Although tile-lang aims to be portable across a range of Devices, it has been sp
- [Flash Attention](./examples/flash_attention/)
- [Flash Linear Attention](./examples/linear_attention/)

Within the `examples` repository, you will also find additional complex kernels—such as convolutions, forward/backward passes for FlashAttention.
Within the `examples` directory, you will also find additional complex kernels—such as convolutions, forward/backward passes for FlashAttention, more operators will continuously be added.


## Benchmark Summary

Expand Down Expand Up @@ -84,8 +85,6 @@ In this section, you’ll learn how to write and execute a straightforward GEMM
Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, and swizzle for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on complex hardware.

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
import tilelang.language as T
# `make_mma_swizzle_layout` is a python defined layout function
Expand All @@ -103,7 +102,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype),
):
# Kernel configuration remains similar
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
Expand All @@ -122,14 +121,14 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
# Clear local accumulation
T.clear(C_local)

for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A
# This is a sugar syntax for parallelized copy
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(A[by * block_M, ko * block_K], A_shared)

# Demonstrate parallelized copy from global to shared for B
for ko, j in T.Parallel(block_K, block_N):
B_shared[ko, j] = B[k * block_K + ko, bx * block_N + j]
for k, j in T.Parallel(block_K, block_N):
B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]

# Perform a tile-level GEMM on the shared buffers
# Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
Expand All @@ -141,7 +140,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
return main


# 1. Define the kernel (matmul) and compile/lower it into an executable module
# 1. Define the kernel (matmul) with the desired dimensions
func = matmul(1024, 1024, 1024, 128, 128, 32)

# 2. Compile the kernel into a torch function
Expand All @@ -158,7 +157,7 @@ a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)


# Run the kernel through the Profiler
# Run the kernel through the JIT-compiled function
c = jit_kernel(a, b)

# Reference multiplication using PyTorch
Expand All @@ -172,7 +171,7 @@ print("Kernel output matches PyTorch reference.")
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)

# 5.Pofile latency with kernel
# 5.Pofile latency with the profiler
profiler = jit_kernel.get_profiler()

latency = profiler.do_bench()
Expand All @@ -189,8 +188,6 @@ In addition to GEMM, we provide a variety of examples to showcase the versatilit
- [LinearAttention](./examples/linear_attention/): Examples include RetNet and Mamba implementations.
- [Convolution](./examples/convolution/): Implementations of Convolution with IM2Col.

More operators will continuously be added.

---

TileLang has now been used in project [BitBLAS](https://github.com/microsoft/BitBLAS).
Expand Down
2 changes: 2 additions & 0 deletions examples/dequantize_gemm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@ def dequant_matmul(
T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct[bx * block_N, by * block_M])
```

**Notes:** Dequantize GEMM with magic layout transformations to get optimal performance can be found at project [BitBLAS](https://github.com/microsoft/BitBLAS), example kernels can be found at `testing/python/kernel/test_tilelang_dequantize_gemm.py`, detailed explanation and examples is coming soon.
Loading
Loading