From 701d06f32f997c7a4a7614aa7c57fef830879b60 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Tue, 6 Aug 2024 14:38:31 -0500 Subject: [PATCH 1/3] Support mfma16 instructions --- scripts/amd/plot_layout.py | 17 ++----- scripts/amd/tikzplot.tex | 94 ++++++++++++++++++++------------------ 2 files changed, 54 insertions(+), 57 deletions(-) diff --git a/scripts/amd/plot_layout.py b/scripts/amd/plot_layout.py index f801a7fd2898..c6a9942886c9 100755 --- a/scripts/amd/plot_layout.py +++ b/scripts/amd/plot_layout.py @@ -82,22 +82,13 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack): \\coordinate (C TL) at ($(C TL)+({N}*\elem+32*\elem, 0)$); \\def\\mfmaTrans{{{trans}}} - \\ifthenelse{{\\mfmaTrans=0}}{{ - \\def\\opColorAL{{magenta}} - \\def\\opColorAR{{cyan}} - \\def\\opColorBL{{Maroon}} - \\def\\opColorBR{{BlueGreen}} - }}{{ - \\def\\opColorBL{{magenta}} - \\def\\opColorBR{{cyan}} - \\def\\opColorAL{{Maroon}} - \\def\\opColorAR{{BlueGreen}} - }} + %% Draw zoomed in view of mfma \\def\\elem{{.16}} \\pgfmathsetmacro{{\\gap}}{{\\elem*5}} \\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}} - \\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+2*{kpack}*\\elem, 0)$); + \\pgfmathsetmacro{{\\groups}}{{64/{mfmaNonKDim}}} + \\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kpack}*\\elem, 0)$); \\drawMFMAInstr{{{mfmaNonKDim}}}{{{kpack}}}{{\\mfmaTrans}} \\end{{tikzpicture}} @@ -195,7 +186,7 @@ def parse_args(): "-nonKDim", type=int, default=32, - choices=[32], + choices=[16, 32], help='mfma instruction dim, only 32 is supported for now') ## blocked layout parameters parser.add_argument("-sizePerThread", type=int, nargs=2, default=(1, 4)) diff --git a/scripts/amd/tikzplot.tex b/scripts/amd/tikzplot.tex index 5f5ff6674b19..e6292f7002e9 100755 --- a/scripts/amd/tikzplot.tex +++ b/scripts/amd/tikzplot.tex @@ -154,30 +154,38 @@ \draw [fill=red] (TL) rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); } 
-\newcommand{\drawBlockMFMALayoutLarge}[2]{ +\newcommand{\drawBlockMFMALayoutLarge}[3]{ %% - %% Draw a single block of MFMA_32x32x8xf16 + %% Draw a single block of MFMA_32x32x8xf16 or MFMA_16x16x16xf16 %% %% block TL: pre-defined top-left coordinate of the block %% \elem: pre defined variable %% %% #1: 1 for mfma.trans, 0 for normal mfma - %% #2: verbose. 1 means draw tid in each vec; 0 means draw nothing + %% #2: mfmaNonKDim + %% #3: verbose. 1 means draw tid in each vec; 0 means draw nothing \pgfmathsetmacro{\trans}{#1} \pgfmathsetmacro{\nonTrans}{1-#1} - \pgfmathsetmacro{\verbose}{#2} - \foreach \iVec in {0,1,2,3} { - \coordinate (wave TL) at ($(block TL)+(\trans*\iVec*2*4*\elem, -\nonTrans*\iVec*2*4*\elem)$); - \foreach \col/\tg in {blue/0,orange/1}{ - \foreach \tid in {0,...,31} { - \pgfmathsetmacro{\ratio}{\tid*2.5+15} + \pgfmathsetmacro{\nonKDim}{#2} + \pgfmathsetmacro{\maxTID}{\nonKDim-1} + \pgfmathsetmacro{\groups}{64/\nonKDim} + \pgfmathsetmacro{\maxGID}{\groups-1} + \pgfmathsetmacro{\maxIVec}{\nonKDim*\nonKDim/256-1} + \pgfmathsetmacro{\verbose}{#3} + \foreach \iVec in {0,...,\maxIVec} { + \coordinate (wave TL) at ($(block TL)+(\trans*\iVec*\groups*4*\elem, -\nonTrans*\iVec*\groups*4*\elem)$); + \foreach \tg in {0,...,\maxGID}{ + \pgfmathsetmacro{\colID}{\tg+4} + \pgfmathsetmacro{\col}{\Colors[\colID]} + \foreach \tid in {0,...,\maxTID} { + \pgfmathsetmacro{\ratio}{\tid*2.5*\groups+15} \ifthenelse{\verbose=0}{ \draw [line width=0.005mm, fill=\col!\ratio!white] ($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$) rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem); }{ - \pgfmathsetmacro{\drawTid}{int(\tid+\tg*32)} + \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} \draw [line width=0.005mm, fill=\col!\ratio!white] ($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$) rectangle ++(\nonTrans*\elem+\trans*4*\elem, 
-\nonTrans*4*\elem-\trans*\elem) @@ -186,7 +194,7 @@ } } } - \draw [thick] (block TL) rectangle ++(32*\elem, -32*\elem); + \draw [thick] (block TL) rectangle ++(\nonKDim*\elem, -\nonKDim*\elem); } @@ -228,7 +236,7 @@ \coordinate (CTA TL) at ($(TL)+(\ctaCoordW*\CTASizeW*\elem, -\ctaCoordH*\CTASizeH*\elem)$); %% Draw a detailed view of wave0 in each CTA \coordinate (block TL) at (CTA TL); - \drawBlockMFMALayoutLarge{\mfmaTrans}{0} + \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{0} \foreach \waveId in {0,...,\maxWaveId}{ \pgfmathsetmacro{\waveCoordH}{int(\waveId/\warpsPerCTAW)} @@ -236,7 +244,7 @@ \coordinate (block TL) at ($(CTA TL)+(\waveCoordW*\mfmaNonKDim*\elem, -\waveCoordH*\mfmaNonKDim*\elem)$); %% Inside the loop, only draw a rectangle \draw [ultra thin] (block TL) rectangle ++(\mfmaNonKDim*\elem, -\mfmaNonKDim*\elem) - node [scale=.7*\scale, pos=.5, fill=white, inner sep=0] {wave\waveId}; + node [scale=.7*\mfmaNonKDim/32*\scale, pos=.5, fill=white, inner sep=0] {wave\waveId}; } %% Draw the outline of each CTA rep @@ -259,28 +267,23 @@ %% #4: verbose. 
1 means draw tid in each vec; 0 means draw nothing \pgfmathsetmacro{\nonKDim}{#1} + \pgfmathsetmacro{\maxGID}{64/\nonKDim-1} + \pgfmathsetmacro{\maxTID}{\nonKDim-1} \pgfmathsetmacro{\kpack}{#2} \pgfmathsetmacro{\opIdxA}{#3} \pgfmathsetmacro{\opIdxB}{1-\opIdxA} \pgfmathsetmacro{\verbose}{#4} - \ifthenelse{\opIdxA = 0}{ - \def\opColorL{\opColorAL} - \def\opColorR{\opColorAR} - }{ - \def\opColorL{\opColorBL} - \def\opColorR{\opColorBR} - } - - \foreach \col/\tg in {\opColorL/0,\opColorR/1}{ - \foreach \tid in {0,...,31} { + \foreach \col/\tg in {0,...,\maxGID}{ + \pgfmathsetmacro{\col}{\Colors[\tg]} + \foreach \tid in {0,...,\maxTID} { % \pgfmathsetmacro{\ratio}{\tid*2.5+15} \ifthenelse{\verbose=0}{ \draw [line width=0.005mm, fill=\col] ($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA); }{ - \pgfmathsetmacro{\drawTid}{int(\tid+\tg*32)} + \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} \draw [line width=0.005mm, fill=\col] ($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA) @@ -304,20 +307,21 @@ \pgfmathsetmacro{\K}{#1} \pgfmathsetmacro{\nonKDim}{#2} + \pgfmathsetmacro{\groups}{64/\nonKDim} \pgfmathsetmacro{\kpack}{#3} \pgfmathsetmacro{\opIdx}{#4} \pgfmathsetmacro{\opIdxOther}{1-\opIdx} \coordinate (TL) at (Op TL); - \pgfmathsetmacro{\numKRep}{\K/\kpack/2} + \pgfmathsetmacro{\numKRep}{\K/\kpack/\groups} \pgfmathsetmacro{\maxKRepId}{\numKRep-1} \foreach \repId in {0,...,\maxKRepId}{ - \coordinate (mfma op TL) at ($(TL)+(\repId*2*\kpack*\elem*\opIdxOther, -\repId*2*\kpack*\elem*\opIdx)$); + \coordinate (mfma op TL) at ($(TL)+(\repId*\groups*\kpack*\elem*\opIdxOther, -\repId*\groups*\kpack*\elem*\opIdx)$); \drawMFMAOperand{\nonKDim}{\kpack}{\opIdx}{0} \draw [thick] (mfma op TL) 
rectangle - ++(2*\kpack*\elem*\opIdxOther+\nonKDim*\opIdx*\elem, -\nonKDim*\opIdxOther*\elem-2*\kpack*\elem*\opIdx); + ++(\groups*\kpack*\elem*\opIdxOther+\nonKDim*\opIdx*\elem, -\nonKDim*\opIdxOther*\elem-\groups*\kpack*\elem*\opIdx); } } @@ -345,41 +349,41 @@ \pgfmathsetmacro{\kpack}{#7} %% operand A - \pgfmathsetmacro{\CTARepM}{\M/\warpsPerCTAM/32} + \pgfmathsetmacro{\CTARepM}{\M/\warpsPerCTAM/\mfmaNonKDim} \pgfmathsetmacro{\maxCTAIdM}{\CTARepM-1} \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM-1} \foreach \ctaId in {0,...,\maxCTAIdM}{ - \coordinate (CTA TL) at ($(A TL)+(0, -\ctaId*\warpsPerCTAM*32*\elem)$); + \coordinate (CTA TL) at ($(A TL)+(0, -\ctaId*\warpsPerCTAM*\mfmaNonKDim*\elem)$); \foreach \waveId in {0,...,\maxWaveId}{ - \coordinate (wave TL) at ($(CTA TL)+(0, -\waveId*32*\elem)$); - \draw [ultra thin] (wave TL) rectangle ++(\K*\elem, -32*\elem); + \coordinate (wave TL) at ($(CTA TL)+(0, -\waveId*\mfmaNonKDim*\elem)$); + \draw [ultra thin] (wave TL) rectangle ++(\K*\elem, -\mfmaNonKDim*\elem); } %% Only draw the detailed view of the first wave in CTA \coordinate (Op TL) at (CTA TL); \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{0} %% Draw the outline of each CTA rep - \draw [ultra thick] (CTA TL) rectangle ++(\K*\elem, -\warpsPerCTAM*32*\elem); + \draw [ultra thick] (CTA TL) rectangle ++(\K*\elem, -\warpsPerCTAM*\mfmaNonKDim*\elem); } \draw [ultra thin] (A TL) rectangle ++(\K*\elem, -\M*\elem); %% operand B - \pgfmathsetmacro{\CTARepN}{\N/\warpsPerCTAN/32} + \pgfmathsetmacro{\CTARepN}{\N/\warpsPerCTAN/\mfmaNonKDim} \pgfmathsetmacro{\maxCTAIdN}{\CTARepN-1} \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAN-1} \foreach \ctaId in {0,...,\maxCTAIdN}{ - \coordinate (CTA TL) at ($(B TL)+(\ctaId*\warpsPerCTAN*32*\elem, 0)$); + \coordinate (CTA TL) at ($(B TL)+(\ctaId*\warpsPerCTAN*\mfmaNonKDim*\elem, 0)$); \foreach \waveId in {0,...,\maxWaveId}{ - \coordinate (wave TL) at ($(CTA TL)+(\waveId*32*\elem ,0)$); - \draw [ultra thin] (wave TL) rectangle ++(32*\elem, -\K*\elem); + 
\coordinate (wave TL) at ($(CTA TL)+(\waveId*\mfmaNonKDim*\elem ,0)$); + \draw [ultra thin] (wave TL) rectangle ++(\mfmaNonKDim*\elem, -\K*\elem); } %% Only draw the detailed view of the first wave in CTA \coordinate (Op TL) at (CTA TL); \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{1} %% Draw the outline of each CTA rep - \draw [ultra thick] (CTA TL) rectangle ++(\warpsPerCTAN*32*\elem, -\K*\elem); + \draw [ultra thick] (CTA TL) rectangle ++(\warpsPerCTAN*\mfmaNonKDim*\elem, -\K*\elem); } \draw [ultra thin] (B TL) rectangle ++(\N*\elem, -\K*\elem); } @@ -405,11 +409,12 @@ \pgfmathsetmacro{\N}{#2} \pgfmathsetmacro{\K}{#3} \pgfmathsetmacro{\mfmaNonKDim}{#4} + \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} \pgfmathsetmacro{\warpsPerCTAM}{#5} \pgfmathsetmacro{\warpsPerCTAN}{#6} \pgfmathsetmacro{\mfmaTrans}{#7} \pgfmathsetmacro{\kpack}{#8} - \pgfmathsetmacro{\kdim}{int(2*\kpack)} + \pgfmathsetmacro{\kdim}{int(\groups*\kpack)} \pgfmathsetmacro{\gap}{\elem*20} \coordinate (A TL) at ($(C TL)+(-\gap-\K*\elem, 0)$); @@ -434,8 +439,8 @@ \node [scale=.8*\scale, left] at ($(A TL)+(0, -.5*\mfmaNonKDim*\elem)$) {\mfmaNonKDim}; \node [scale=.8*\scale, above] at ($(B TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim}; %% label kpack - \node [scale=.8*\scale, above] at ($(A TL)+(\kpack*\elem, 0)$) {\kdim}; - \node [scale=.8*\scale, left] at ($(B TL)+(0, -\kpack*\elem)$) {\kdim}; + \node [scale=.8*\scale, above] at ($(A TL)+(0.5*\groups*\kpack*\elem, 0)$) {\kdim}; + \node [scale=.8*\scale, left] at ($(B TL)+(0, -0.5*\groups*\kpack*\elem)$) {\kdim}; } \newcommand{\Colors}{{ @@ -692,19 +697,20 @@ %% #2: kpack %% #3: mfmaTrans \pgfmathsetmacro{\mfmaNonKDim}{#1} + \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} \pgfmathsetmacro{\kpack}{#2} \pgfmathsetmacro{\mfmaTrans}{#3} \pgfmathsetmacro{\nonTrans}{1-#3} \pgfmathsetmacro{\gap}{\elem*5} - \coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-2*\kpack*\elem, 0)$); + \coordinate (mfma opA TL) at ($(C 
TL)+(-.5*\gap-1.2*\nonTrans*\gap-\groups*\kpack*\elem, 0)$); \coordinate (mfma op TL) at (mfma opA TL); \drawMFMAOperand{\mfmaNonKDim}{\kpack}{0}{1} - \coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+2*\kpack*\elem)$); + \coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kpack*\elem)$); \drawMFMAOperand{\mfmaNonKDim}{\kpack}{1}{1} \coordinate (block TL) at (C TL); - \drawBlockMFMALayoutLarge{\mfmaTrans}{1} + \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{1} %% Draw labels \def\vecR{1.5} @@ -736,7 +742,7 @@ \draw [densely dotted] (block TL) -- ++(-3*\elem, .25*4*\elem); \draw [densely dotted] ($(block TL)+(0, -4*\elem)$) -- ++(-3*\elem, -.25*4*\elem); \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem*\vecR)$) {vec=4}; - \node [scale=.8*\scale, above, align=center] at ($(block TL)+(16*\elem, 0)$) {mfmaLayout\\trans=False}; + \node [scale=.8*\scale, above, align=center] at ($(block TL)+(.5*\mfmaNonKDim*\elem, 0)$) {mfmaLayout\\trans=False}; }{ \node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opB}; \node [scale=\scale, above] at (mfma op TL) {opA}; From 058c7c7a86740a78c236c7c7a8c1adcc972a965c Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Tue, 6 Aug 2024 15:16:17 -0500 Subject: [PATCH 2/3] Rename kpack to kWidth and refactor --- scripts/amd/plot_layout.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/scripts/amd/plot_layout.py b/scripts/amd/plot_layout.py index c6a9942886c9..325202c12c63 100755 --- a/scripts/amd/plot_layout.py +++ b/scripts/amd/plot_layout.py @@ -194,11 +194,11 @@ def parse_args(): parser.add_argument("-warpsPerCTA", type=int, nargs=2, default=(1, 4)) parser.add_argument("-order", type=int, nargs=2, default=(1, 0)) ## LDS access parameters - parser.add_argument("-kpack", + parser.add_argument("-kWidth", type=int, default=4, - choices=[4, 8], - help='vector length during LDS load, same as vec') + 
choices=[4, 8, 16], + help='number of elements per thread') parser.add_argument("-lds_layout", type=str, default="none", @@ -224,7 +224,7 @@ def parse_args(): action='store_true', default=False, help='If set, then use mfma.trans layout') - parser.add_argument("--keep", + parser.add_argument("-keep", action='store_true', default=False, help='If set, keep the generated .tex file') @@ -243,7 +243,7 @@ def main(): K = shape[2] plot_mode = args.plot mfmaNonKDim = args.nonKDim - kpack = args.kpack + kpack = args.kWidth trans = 1 if args.mfmaTrans else 0 ofilename = args.o keepSrc = args.keep @@ -269,13 +269,13 @@ def main(): CTAShape.append(sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1]) if plot_mode == 'dot': - mfma_inst_str = "mfma_32x32x8f16" if mfmaNonKDim == 32 else "mfma_16x16x16f16" + mfma_inst_str = "mfma_32x32" if mfmaNonKDim == 32 else "mfma_16x16" mfma_trans_str = ".trans" if trans else "" print(f"Plotting dot operation with shapes M={M},N={N},K={K}") - print("MFMA: " + mfma_inst_str + mfma_trans_str, end=" ") + print("MFMA: " + mfma_inst_str + mfma_trans_str + f" kWidth = {kpack}", end=" ") print(f"warpsPerCTA={warpsPerCTA}", end=" ") - CTAShape.append(32 * warpsPerCTA[0]) - CTAShape.append(32 * warpsPerCTA[1]) + CTAShape.append(mfmaNonKDim * warpsPerCTA[0]) + CTAShape.append(mfmaNonKDim * warpsPerCTA[1]) if plot_mode == 'blocked' or plot_mode == 'dot': print(f"CTAShape={CTAShape}") From 25d34122acb3965cd365dffc2911e0af0294578f Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Tue, 6 Aug 2024 21:52:39 -0500 Subject: [PATCH 3/3] Added a README for the plot script --- scripts/amd/README.md | 117 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 scripts/amd/README.md diff --git a/scripts/amd/README.md b/scripts/amd/README.md new file mode 100644 index 000000000000..26b4ef15aa8d --- /dev/null +++ b/scripts/amd/README.md @@ -0,0 +1,117 @@ +# Plot script for triton layouts + +This script is used to draw triton 
layouts in the context of matmul. +Here is the help info from the script. + +```bash +>$ python3 plot_layout.py -h +usage: Draw triton layouts [-h] [-shape SHAPE SHAPE SHAPE] [-plot {blocked,dot,wmma,lds}] [-nonKDim {16,32}] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD] [-threadsPerWarp THREADSPERWARP THREADSPERWARP] + [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-kWidth {4,8,16}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-mfmaTrans] [-keep] + +options: + -h, --help show this help message and exit + -shape SHAPE SHAPE SHAPE + Tensor shape in the form of M,N,K + -plot {blocked,dot,wmma,lds} + choose plot mode + -nonKDim {16,32} mfma instruction dim, only 32 is supported for now + -sizePerThread SIZEPERTHREAD SIZEPERTHREAD + -threadsPerWarp THREADSPERWARP THREADSPERWARP + -warpsPerCTA WARPSPERCTA WARPSPERCTA + -order ORDER ORDER + -kWidth {4,8,16} number of elements per thread + -lds_layout {swizzle,padding,none} + choose the LDS data layout + -lds_access {read,write,none} + choose LDS access mode + -wave_size {32,64} choose the wmma instruction mode + -o O output pdf file name (without surfix) + -mfmaTrans If set, then use mfma.trans layout + -keep If set, keep the generated .tex file +``` + +## Installation +This script does not require torch or triton to be installed. The only package +it depends on is latex. On Ubuntu, do +```bash +sudo apt install texlive-full +``` + +## Draw blocked layout (`-plot blocked`) + +Examples: +```bash +python3 plot_layout.py -plot blocked -shape 128 128 64 -sizePerThread 1 8 -threadsPerWarp 8 8 -warpsPerCTA 4 1 +python3 plot_layout.py -plot blocked -shape 16 128 64 -sizePerThread 1 8 -threadsPerWarp 16 4 -warpsPerCTA 1 2 +python3 plot_layout.py -plot blocked -shape 32 128 64 -sizePerThread 8 1 -threadsPerWarp 4 16 -warpsPerCTA 1 2 -order 0 1 +``` + +Blocked layouts are used during global load. 
They are used to describe the layout of the tensor +for pointers and results. +We can provide tensor shape (`-shape M N K`) and blocked layout parameters ( +`-sizePerThread x y`, `-threadsPerWarp x y`, and `-warpsPerCTA x y`). +We can also provide the order of the tensor as `-order x y` to control which dim +is the fastest changing dimension. + +Notes +- All of the gemm dims (M, N, and K) are needed when providing the shape. But only + M and K will be used to plot the layout of the tensor. +- The script does not support the case when threads are loading elements that are + out of the boundary of the tensor dimensions. This means + - For M: sizePerThread[0] * threadsPerWarp[0] * warpsPerCTA[0] <= M + - For K: sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1] <= K + + +## Draw mfma operand and result layouts (`-plot dot`) + +Examples: +```bash +python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 4 +python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 +python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 -mfmaTrans +python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 8 +python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 16 +``` + +This mode draws two graphs: +1. The layout of the whole tile for tile A, B, and C +2. The layout of a single mfma block, operands and results of one or more mfma + instructions that share the same accumulating VGPRs. + This view has thread distributions among tensor elements. + +Knobs +- `-kWidth`: the number of elements that will be loaded into one thread at once +- `-nonKDim`: 16 or 32, which is used to control the mfma instruction size +- `-mfmaTrans`: if set, the transposed mfma layout will be plotted. + +Notes +- The layout shows the mapping from the threads/wave to the elements in the + original tensor. 
It does not reflect how the elements are arranged in LDS, such as + swizzling to avoid bank conflicts. +- The script does not have explicit settings for the data type or k dim of the mfma instruction. + These are controlled implicitly by the `-kWidth` flag. + - For example, if we want `mfma_32x32x8xf16`, we can set `-nonKDim 32` and `-kWidth 4`. + - If we want `mfma_32x32x16xf8`, we can set `-nonKDim 32` and `-kWidth 8`. + + +## Draw LDS access (`-plot lds`) + +Examples: +```bash +python3 plot_layout.py -plot lds -lds_layout none -lds_access none -shape 128 128 64 -kWidth 8 +``` + +Knobs +- `-kWidth` here means the vector size when accessing LDS +- Three options for `-lds_layout`: + - `none`: no swizzling, no padding + - `padding`: padding at every 128B + - `swizzle`: apply the swizzling pattern, which is derived from tensor shape and kWidth. +- Three options for `-lds_access`: + - `none`: do not plot access pattern + - `read`: plot accessed elements during ds_read + - `write`: plot accessed elements during ds_write. Note that this needs some information from + global load. Therefore, we need to provide `-sizePerThread` and `-threadsPerWarp`. + +Notes +- This mode is rarely used. If you have any questions, please contact Lixun Zhang directly.