Skip to content

Commit

Permalink
Rename kpack to kWidth and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglx13 committed Aug 7, 2024
1 parent 701d06f commit 058c7c7
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions scripts/amd/plot_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ def parse_args():
parser.add_argument("-warpsPerCTA", type=int, nargs=2, default=(1, 4))
parser.add_argument("-order", type=int, nargs=2, default=(1, 0))
## LDS access parameters
parser.add_argument("-kpack",
parser.add_argument("-kWidth",
type=int,
default=4,
choices=[4, 8],
help='vector length during LDS load, same as vec')
choices=[4, 8, 16],
help='number of elements per thread')
parser.add_argument("-lds_layout",
type=str,
default="none",
Expand All @@ -224,7 +224,7 @@ def parse_args():
action='store_true',
default=False,
help='If set, then use mfma.trans layout')
parser.add_argument("--keep",
parser.add_argument("-keep",
action='store_true',
default=False,
help='If set, keep the generated .tex file')
Expand All @@ -243,7 +243,7 @@ def main():
K = shape[2]
plot_mode = args.plot
mfmaNonKDim = args.nonKDim
kpack = args.kpack
kpack = args.kWidth
trans = 1 if args.mfmaTrans else 0
ofilename = args.o
keepSrc = args.keep
Expand All @@ -269,13 +269,13 @@ def main():
CTAShape.append(sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1])

if plot_mode == 'dot':
mfma_inst_str = "mfma_32x32x8f16" if mfmaNonKDim == 32 else "mfma_16x16x16f16"
mfma_inst_str = "mfma_32x32" if mfmaNonKDim == 32 else "mfma_16x16"
mfma_trans_str = ".trans" if trans else ""
print(f"Plotting dot operation with shapes M={M},N={N},K={K}")
print("MFMA: " + mfma_inst_str + mfma_trans_str, end=" ")
print("MFMA: " + mfma_inst_str + mfma_trans_str + f" kWidth = {kpack}", end=" ")
print(f"warpsPerCTA={warpsPerCTA}", end=" ")
CTAShape.append(32 * warpsPerCTA[0])
CTAShape.append(32 * warpsPerCTA[1])
CTAShape.append(mfmaNonKDim * warpsPerCTA[0])
CTAShape.append(mfmaNonKDim * warpsPerCTA[1])

if plot_mode == 'blocked' or plot_mode == 'dot':
print(f"CTAShape={CTAShape}")
Expand Down

0 comments on commit 058c7c7

Please sign in to comment.