Reformat *tune_gemm* files with Triton's pre-commit
The following command was executed to reformat the files:
```
$ pre-commit run --files \
    python/perf-kernels/tune_gemm/* \
    python/perf-kernels/tune_gemm/utils/*
```
brunomazzottiamd committed Aug 15, 2024
1 parent 11e4447 commit 4f75b0f
Showing 9 changed files with 268 additions and 534 deletions.
30 changes: 15 additions & 15 deletions python/perf-kernels/tune_gemm/README.md
@@ -12,15 +12,15 @@ This means `BLOCK_SIZE_K` does not need to divide K dim.

### Differences between the tutorial

Unlike the [matmul tutorial](https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py) (referred to as the tutorial),
the matmul kernel used in the tuning script (referred to as the kernel) does not
guard loads along the M and N dims
([this](https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py#L282-L283) shows how this is done in the tutorial).
When `BLOCK_SIZE_M` or `BLOCK_SIZE_N` does not divide M or N, the kernel will
load out-of-bounds data.
In most cases this is fine, since the kernel does a masked store at the end.
However, this may lead to a GPU memory access fault in some cases, especially
when the tensor is large.
We will fix this issue in the future.
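
The following is a minimal sketch of the unguarded-load / masked-store pattern described above. It uses a trivial copy kernel (not the tuned GEMM kernel) so it stays short; it assumes a GPU is available, like all kernels in this directory.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)                  # unguarded load: may read past n
    tl.store(y_ptr + offs, x, mask=offs < n)   # masked store keeps the result correct


n = 1000
x = torch.randn(n, device="cuda")
y = torch.empty_like(x)
# BLOCK does not divide n, so the last program reads a little past the buffer,
# which is usually harmless but can fault for very large tensors.
copy_kernel[(triton.cdiv(n, 256), )](x, y, n, BLOCK=256)
```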


@@ -53,7 +53,7 @@ The following `options` are supported in the tuning mode
kernel launch is not supported.
- Parallel profiling of kernels: The tuning space is first divided into a number
of tasks, which is controlled by `--jobs n`. And all the tasks can be profiled in
parallel on a number of GPUs in the system. There are two ways to specify which
GPU(s) we want to use for profiling. Note that these flags cannot be used together.
By default, only one task is generated and profiled on GPU0.
- `--ngpus n`: GPU 0,1,.., n-1 will be used.
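
Below is an illustrative sketch (not the tuning script's actual bookkeeping) of how a tuning space could be split into `--jobs n` tasks and spread over the GPUs selected for profiling, e.g. GPUs 0..n-1 when `--ngpus n` is given; the round-robin assignment is an assumption.

```python
def split_tuning_space(configs, jobs):
    # Task i gets every jobs-th config starting at i.
    return [configs[i::jobs] for i in range(jobs)]


def assign_tasks_to_gpus(num_tasks, gpu_ids):
    # Round-robin: task i is profiled on gpu_ids[i % len(gpu_ids)].
    return {i: gpu_ids[i % len(gpu_ids)] for i in range(num_tasks)}


tasks = split_tuning_space(list(range(100)), jobs=8)
gpu_of_task = assign_tasks_to_gpus(len(tasks), gpu_ids=list(range(4)))
```
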
@@ -138,7 +138,7 @@ The supported `options` are as follows
The default value is 1000.
- `--icache_flush`: same as tuning mode
- `--rotating_tensor SIZE`: same as tuning mode


## Tuning script implementation overview

@@ -178,7 +178,7 @@ Workflow of the tuning process
1. Generate the full tuning space. For now the `range`s for each tuning parameter are hard-coded
2. Prune the tuning space according to the current GEMM size and some rules
- BLOCK_SIZE must be equal to or larger than the mfma instruction size.
- SPLIT_K * BLOCK_SIZE_K must divide K. Therefore, we do not need EVEN_K in the kernel.
- When split-k is not needed, i.e. both M and N are large, it must be 1
- GROUP_M * BLOCK_SIZE_M must be smaller than M. Otherwise, GROUP_M must be 1
- When BLOCK_SIZE_K = 128, neither BLOCK_SIZE_M nor BLOCK_SIZE_N can be 128. Otherwise too much LDS will be required. **Needs further investigation**
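
A hedged sketch of how the pruning rules above might be expressed is shown here; the config keys and the "large M and N" threshold are illustrative assumptions, not the script's actual code.

```python
def keep_config(cfg, M, N, K, mfma_size=16, large_dim=1024):
    if cfg["BLOCK_SIZE_M"] < mfma_size or cfg["BLOCK_SIZE_N"] < mfma_size:
        return False  # BLOCK_SIZE must be at least the mfma instruction size
    if K % (cfg["SPLIT_K"] * cfg["BLOCK_SIZE_K"]) != 0:
        return False  # SPLIT_K * BLOCK_SIZE_K must divide K, so no EVEN_K is needed
    if M >= large_dim and N >= large_dim and cfg["SPLIT_K"] != 1:
        return False  # split-k is not needed when both M and N are large
    if cfg["GROUP_M"] * cfg["BLOCK_SIZE_M"] >= M and cfg["GROUP_M"] != 1:
        return False  # otherwise GROUP_M must fall back to 1
    if cfg["BLOCK_SIZE_K"] == 128 and 128 in (cfg["BLOCK_SIZE_M"], cfg["BLOCK_SIZE_N"]):
        return False  # would require too much LDS
    return True
```
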
@@ -188,7 +188,7 @@ Workflow of the tuning process
2. Generate `matmul` function for each config in a similar way
3. Generate `try_config` functions for each `matmul` function.
4. Generate `test_gemm`, which does
1. Add all `try_config` functions in the thread_pool by `thread_pool.apply_async(try_config)`. This is used to compile all kernels in parallel.
2. Call each `matmul` function in a for loop of 10 iterations
5. Generate `main` function
4. Run the generated script with 16 workers. This will compile all kernels in parallel.
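
A minimal sketch of the compile-driver structure described in steps 3-4 follows; the real generated file contains one `matmul`/`try_config` pair per pruned config, so the placeholder function bodies here are assumptions.

```python
from multiprocessing.pool import ThreadPool


def try_config_0():
    # In the generated script this compiles one kernel config
    # (e.g. via the kernel's warmup path) without timing it.
    pass


def try_config_1():
    pass


def test_gemm(num_workers=16):
    pool = ThreadPool(processes=num_workers)
    for fn in (try_config_0, try_config_1):
        pool.apply_async(fn)   # compile all kernels in parallel
    pool.close()
    pool.join()


if __name__ == "__main__":
    test_gemm()
```
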
@@ -203,7 +203,7 @@ Workflow of the tuning process
The provided types must be one of ['fp32', 'fp16', 'bf16', 'fp8', 'bf8', 'int8'].
- Row/col major-ness of operand a and b can be provided as `-col_a` and `-col_b`.
If set, it means the corresponding operand is column major.
The major-ness is considered as problem input.
So they should be included in the input yaml file. However, in the yaml file, the user should
set `rowMajorA` and `rowMajorB` as shown in the example below.
- `--benchmark` is used to control if the perf config in the input yaml file is used as the tuning space.
@@ -218,7 +218,7 @@ This is necessary to keep each file "small" in terms of execution time.
- In benchmark mode, the kernel is executed 1000 times.
- In tuning mode, each kernel is executed 200 times. We cannot afford larger runs since rocprof hangs if the session takes too long.
- In both tuning and benchmark mode, kernel time is measured as the average execution time of the last 100 instances.
- Added error recovery. This helps when rocprof crashes in multi-processing mode.
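
As a small sketch only, the reported kernel time could be derived from per-instance timings following the "average of the last 100 instances" rule above; the list of timings is assumed to come from the profiler output.

```python
def average_kernel_time(times_us, tail=100):
    tail_times = times_us[-tail:]          # drop the earlier, warm-up instances
    return sum(tail_times) / len(tail_times)
```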



@@ -233,12 +233,12 @@ This is necessary to keep each file "small" in terms of execution time.

### API changes

- Added `--rotating_tensor <value>` to use rotating memory blocks in each iteration, size in MB. Default is 0MB.
- Added `--icache_flush` to flush icache in each iteration.
Note: icache flush needs the `hip-python` module, which can be installed with:
`python3 -m pip install -i https://test.pypi.org/simple hip-python~=$rocm_version`
Rotating tensors and icache flushing are used to make perf numbers closer to those in real applications.
- Added `--bias_vector` to support kernel execution with bias (the bias vector is of the same size as the number of rows of the output matrix,
so each element of the bias vector is added to all elements of the corresponding row of the output matrix).
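
A plain-PyTorch illustration of the `--bias_vector` semantics described above (CPU, fp32, arbitrary sizes); this is not the Triton kernel itself, only the broadcast it matches.

```python
import torch

M, N, K = 4, 8, 16
a = torch.randn(M, K)
b = torch.randn(K, N)
bias = torch.randn(M)          # one entry per output row

c = a @ b + bias[:, None]      # bias[i] is added to every column of row i
```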


@@ -283,11 +283,11 @@ that cannot divide `K`.
- The tuning result file is opened and closed inside the tuning loop, enabling a timely flush
of the tuning results.
- Now we use `rocprofv2` to measure kernel time.
- We can use `--hack_triton_compile` to avoid all GPU activities during the compilation
stage. This is achieved by modifying the triton frontend compiler in the following
places:
- Return True from the `is_active()` function in the hip backend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/amd/backend/driver.py#L433)
- Return statically constructed GPUTarget from the `get_current_target()`
function in the hip backend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/amd/backend/driver.py#L437)
- Return False from the `is_active()` function in the cuda backend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/nvidia/backend/driver.py#L383)
- Statically set `device` and `stream` in the [jit.py](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/python/triton/runtime/jit.py#L588-L589)
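
The following is a rough sketch of the kind of hand edits the list above describes, applied to the hip backend driver in the Triton source tree (file paths per the links above). The arch string, warp size, and exact import path are version-dependent assumptions.

```python
from triton.backends.compiler import GPUTarget


def is_active():
    # hip backend: claim the backend is active without querying the GPU.
    return True


def get_current_target():
    # hip backend: return a statically constructed target instead of asking
    # the device. Separately, the cuda backend's is_active() is made to return
    # False, and jit.py gets a static `device`/`stream`, as listed above.
    return GPUTarget("hip", "gfx942", 64)
```
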
28 changes: 8 additions & 20 deletions python/perf-kernels/tune_gemm/icache_flush.py
@@ -1,13 +1,9 @@
import ctypes
import array
import random
import math

# the hip module can be installed as
# `python3 -m pip install -i https://test.pypi.org/simple hip-python~=$rocm_version`
# more information about hip-python is at: https://github.com/ROCm/hip-python
from hip import hip, hiprtc


def hip_check(call_result):
err = call_result[0]
result = call_result[1:]
@@ -16,14 +12,12 @@ def hip_check(call_result):

if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
raise RuntimeError(str(err))
elif (
isinstance(err, hiprtc.hiprtcResult)
and err != hiprtc.hiprtcResult.HIPRTC_SUCCESS
):
elif (isinstance(err, hiprtc.hiprtcResult) and err != hiprtc.hiprtcResult.HIPRTC_SUCCESS):
raise RuntimeError(str(err))

return result


# S_ICACHE_INV Invalidate entire first level instruction cache.
# There must be 16 separate S_NOP instructions or a jump/branch instruction
# after this instruction to ensure the internal instruction buffers are also
@@ -56,7 +50,7 @@ def gen_kernel():
progs = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(progs, 0))
arch = progs.gcnArchName
cflags = [b"--offload-arch="+arch]
cflags = [b"--offload-arch=" + arch]
err, = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags)
if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS:
log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog))
@@ -73,22 +67,16 @@ def gen_kernel():

return kernel


kernel = gen_kernel()
progs = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(progs, 0))
cu_num = progs.multiProcessorCount


def icache_flush():
block = hip.dim3(x=64)
grid = hip.dim3(cu_num * 60)

hip_check(hip.hipModuleLaunchKernel(
kernel,
*grid,
*block,
sharedMemBytes=0,
stream=None,
kernelParams=None,
extra=()
)
)
hip_check(
hip.hipModuleLaunchKernel(kernel, *grid, *block, sharedMemBytes=0, stream=None, kernelParams=None, extra=()))
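
For context, here is a hypothetical sketch of how `icache_flush()` could be used in a benchmarking loop; `run_kernel()` stands in for whatever launches the GEMM being timed and is an assumption, not part of the file above.

```python
from icache_flush import icache_flush


def bench(run_kernel, iters=200):
    for _ in range(iters):
        icache_flush()   # invalidate the L1 instruction cache before each run
        run_kernel()
```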