[plot tool] Add mfma16 support #624

Merged
merged 3 commits into from
Aug 7, 2024
117 changes: 117 additions & 0 deletions scripts/amd/README.md
@@ -0,0 +1,117 @@
# Plot script for triton layouts

This script draws Triton layouts in the context of matmul.
The script's help output is shown below.

```bash
>$ python3 plot_layout.py -h
usage: Draw triton layouts [-h] [-shape SHAPE SHAPE SHAPE] [-plot {blocked,dot,wmma,lds}] [-nonKDim {16,32}] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD] [-threadsPerWarp THREADSPERWARP THREADSPERWARP]
[-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-kWidth {4,8,16}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-mfmaTrans] [-keep]

options:
-h, --help show this help message and exit
-shape SHAPE SHAPE SHAPE
Tensor shape in the form of M,N,K
-plot {blocked,dot,wmma,lds}
choose plot mode
-nonKDim {16,32} mfma instruction dim, only 32 is supported for now
-sizePerThread SIZEPERTHREAD SIZEPERTHREAD
-threadsPerWarp THREADSPERWARP THREADSPERWARP
-warpsPerCTA WARPSPERCTA WARPSPERCTA
-order ORDER ORDER
-kWidth {4,8,16} number of elements per thread
-lds_layout {swizzle,padding,none}
choose the LDS data layout
-lds_access {read,write,none}
choose LDS access mode
-wave_size {32,64} choose the wmma instruction mode
-o O output pdf file name (without surfix)
-mfmaTrans If set, then use mfma.trans layout
-keep If set, keep the generated .tex file
```

## Installation
This script does not require torch or triton to be installed. Its only
dependency is LaTeX. On Ubuntu, install it with
```bash
sudo apt install texlive-full
```

## Draw blocked layout (`-plot blocked`)

Examples:
```bash
python3 plot_layout.py -plot blocked -shape 128 128 64 -sizePerThread 1 8 -threadsPerWarp 8 8 -warpsPerCTA 4 1
python3 plot_layout.py -plot blocked -shape 16 128 64 -sizePerThread 1 8 -threadsPerWarp 16 4 -warpsPerCTA 1 2
python3 plot_layout.py -plot blocked -shape 32 128 64 -sizePerThread 8 1 -threadsPerWarp 4 16 -warpsPerCTA 1 2 -order 0 1
```

Blocked layouts are used during global load. They describe the layout of the tensor
for pointers and results.
We can provide the tensor shape (`-shape M N K`) and the blocked layout parameters
(`-sizePerThread x y`, `-threadsPerWarp x y`, and `-warpsPerCTA x y`).
We can also provide the order of the tensor as `-order x y` to control which dim
is the fastest changing dimension.

Notes
- All of the gemm dims (M, N, and K) are needed when providing the shape, but only
  M and K are used to plot the layout of the tensor.
- The script does not support the case where threads load elements that are
  outside the boundary of the tensor dimensions. This means
  - For M: `sizePerThread[0] * threadsPerWarp[0] * warpsPerCTA[0] <= M`
  - For K: `sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1] <= K`
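
The boundary constraint above can be checked before invoking the script. The following is a minimal sketch, not part of `plot_layout.py`; the helper names `blocked_cta_tile` and `fits_in_tensor` are hypothetical and exist only for illustration.

```python
def blocked_cta_tile(size_per_thread, threads_per_warp, warps_per_cta):
    """Return the (M, K) extent covered by one CTA in a blocked layout.

    Each extent is sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i].
    """
    return tuple(
        s * t * w
        for s, t, w in zip(size_per_thread, threads_per_warp, warps_per_cta)
    )


def fits_in_tensor(M, K, size_per_thread, threads_per_warp, warps_per_cta):
    """Check that the CTA tile does not exceed the tensor in M or K."""
    tile_m, tile_k = blocked_cta_tile(size_per_thread, threads_per_warp, warps_per_cta)
    return tile_m <= M and tile_k <= K


# First example above: -shape 128 128 64 -sizePerThread 1 8
# -threadsPerWarp 8 8 -warpsPerCTA 4 1
print(blocked_cta_tile((1, 8), (8, 8), (4, 1)))
print(fits_in_tensor(128, 64, (1, 8), (8, 8), (4, 1)))
```

All three blocked examples above satisfy this check. A configuration that fails it, such as the same layout parameters with M = 16, falls into the unsupported case described in the notes.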


## Draw mfma operand and result layouts (`-plot dot`)

Examples:
```bash
python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 4
python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8
python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 -mfmaTrans
python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 8
python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 16
```

This mode draws two graphs:
1. The layout of the whole tile for tiles A, B, and C
2. The layout of a single mfma block, that is, the operands and results of one or more mfma
   instructions that share the same accumulating VGPRs.
   This view shows how threads are distributed among tensor elements.

Knobs
- `-kWidth`: the number of elements that will be loaded into one thread at once
- `-nonKDim`: 16 or 32, which controls the mfma instruction size
- `-mfmaTrans`: if set, the transposed mfma layout will be plotted.

Notes
- The layout shows the mapping from the threads/wave to the elements in the
  original tensor. It does not account for how the elements are arranged in LDS,
  such as swizzling to avoid bank conflicts.
- The script has no direct setting for the data type or k dim of the mfma instruction.
  These are controlled indirectly through the `-kWidth` flag.
  - For example, if we want `mfma_32x32x8xf16`, we can set `-nonKDim 32` and `-kWidth 4`.
  - If we want `mfma_32x32x16xf8`, we can set `-nonKDim 32` and `-kWidth 8`.
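
The two examples above are consistent with a simple rule: the k dim equals `kWidth` times the number of thread groups along k, where the group count is `64 / nonKDim` (the same expression the script uses for `\groups` in the zoomed-in mfma view). The sketch below encodes that rule. It is an inference from the examples, not something the script exposes, and the helper name `mfma_k_dim` is hypothetical.

```python
def mfma_k_dim(non_k_dim, k_width, wave_size=64):
    """Estimate the k dim of the mfma instruction implied by the flags.

    Assumption: a wave of `wave_size` threads is split into
    wave_size / nonKDim groups along k, and each thread holds
    `k_width` consecutive k elements.
    """
    groups = wave_size // non_k_dim
    return groups * k_width


# -nonKDim 32 -kWidth 4, the mfma_32x32x8 example above
print(mfma_k_dim(32, 4))
# -nonKDim 32 -kWidth 8, the mfma_32x32x16 example above
print(mfma_k_dim(32, 8))
```

The data type is not derived here. As noted above, the script does not take it as input, so it has to be read off from the chosen `-kWidth`.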


## Draw LDS access (`-plot lds`)

Examples:
```bash
python3 plot_layout.py -plot lds -lds_layout none -lds_access none -shape 128 128 64 -kWidth 8
```

Knobs
- `-kWidth` here means the vector size when accessing LDS
- Three options for `-lds_layout`:
  - `none`: no swizzling, no padding
  - `padding`: padding at every 128B
  - `swizzle`: apply the swizzling pattern, which is derived from tensor shape and kWidth.
- Three options for `-lds_access`:
  - `none`: do not plot access pattern
  - `read`: plot accessed elements during ds_read
  - `write`: plot accessed elements during ds_write. Note that this needs some information from
    global load. Therefore, we need to provide `-sizePerThread` and `-threadsPerWarp`.

Notes
- This mode is rarely used. If you have any questions, please contact Lixun Zhang directly.
35 changes: 13 additions & 22 deletions scripts/amd/plot_layout.py
@@ -82,22 +82,13 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack):

\\coordinate (C TL) at ($(C TL)+({N}*\elem+32*\elem, 0)$);
\\def\\mfmaTrans{{{trans}}}
\\ifthenelse{{\\mfmaTrans=0}}{{
\\def\\opColorAL{{magenta}}
\\def\\opColorAR{{cyan}}
\\def\\opColorBL{{Maroon}}
\\def\\opColorBR{{BlueGreen}}
}}{{
\\def\\opColorBL{{magenta}}
\\def\\opColorBR{{cyan}}
\\def\\opColorAL{{Maroon}}
\\def\\opColorAR{{BlueGreen}}
}}

%% Draw zoomed in view of mfma
\\def\\elem{{.16}}
\\pgfmathsetmacro{{\\gap}}{{\\elem*5}}
\\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}}
\\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+2*{kpack}*\\elem, 0)$);
\\pgfmathsetmacro{{\\groups}}{{64/{mfmaNonKDim}}}
\\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kpack}*\\elem, 0)$);
\\drawMFMAInstr{{{mfmaNonKDim}}}{{{kpack}}}{{\\mfmaTrans}}

\\end{{tikzpicture}}
@@ -195,19 +186,19 @@ def parse_args():
"-nonKDim",
type=int,
default=32,
choices=[32],
choices=[16, 32],
help='mfma instruction dim, only 32 is supported for now')
## blocked layout parameters
parser.add_argument("-sizePerThread", type=int, nargs=2, default=(1, 4))
parser.add_argument("-threadsPerWarp", type=int, nargs=2, default=(16, 4))
parser.add_argument("-warpsPerCTA", type=int, nargs=2, default=(1, 4))
parser.add_argument("-order", type=int, nargs=2, default=(1, 0))
## LDS access parameters
parser.add_argument("-kpack",
parser.add_argument("-kWidth",
type=int,
default=4,
choices=[4, 8],
help='vector length during LDS load, same as vec')
choices=[4, 8, 16],
help='number of elements per thread')
parser.add_argument("-lds_layout",
type=str,
default="none",
@@ -233,7 +224,7 @@ def parse_args():
action='store_true',
default=False,
help='If set, then use mfma.trans layout')
parser.add_argument("--keep",
parser.add_argument("-keep",
action='store_true',
default=False,
help='If set, keep the generated .tex file')
@@ -252,7 +243,7 @@ def main():
K = shape[2]
plot_mode = args.plot
mfmaNonKDim = args.nonKDim
kpack = args.kpack
kpack = args.kWidth
trans = 1 if args.mfmaTrans else 0
ofilename = args.o
keepSrc = args.keep
@@ -278,13 +269,13 @@ def main():
CTAShape.append(sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1])

if plot_mode == 'dot':
mfma_inst_str = "mfma_32x32x8f16" if mfmaNonKDim == 32 else "mfma_16x16x16f16"
mfma_inst_str = "mfma_32x32" if mfmaNonKDim == 32 else "mfma_16x16"
mfma_trans_str = ".trans" if trans else ""
print(f"Plotting dot operation with shapes M={M},N={N},K={K}")
print("MFMA: " + mfma_inst_str + mfma_trans_str, end=" ")
print("MFMA: " + mfma_inst_str + mfma_trans_str + f" kWidth = {kpack}", end=" ")
print(f"warpsPerCTA={warpsPerCTA}", end=" ")
CTAShape.append(32 * warpsPerCTA[0])
CTAShape.append(32 * warpsPerCTA[1])
CTAShape.append(mfmaNonKDim * warpsPerCTA[0])
CTAShape.append(mfmaNonKDim * warpsPerCTA[1])

if plot_mode == 'blocked' or plot_mode == 'dot':
print(f"CTAShape={CTAShape}")