Skip to content

Commit

Permalink
add large-k shapes to addmm
Browse files Browse the repository at this point in the history
Summary: Add shapes from the diff linked in T186231930

Reviewed By: xuzhao9

Differential Revision: D60489313

fbshipit-source-id: c49483dc164affc0c383966d164fd4c9a5239807
  • Loading branch information
karthik-man authored and facebook-github-bot committed Aug 1, 2024
1 parent d236f4c commit f4ed185
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
1 change: 1 addition & 0 deletions torchbenchmark/operators/addmm/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ def parse_args(args: List[str]) -> argparse.Namespace:
parser.add_argument("--n", type=int)
parser.add_argument("--input", type=str)
parser.add_argument("--col-major", type=bool, default=False)
parser.add_argument("--large-k-shapes", type=bool, default=False)
args = parser.parse_args(args)
return args
7 changes: 6 additions & 1 deletion torchbenchmark/operators/addmm/operator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import argparse
import itertools
import os
from typing import Any, Callable, Generator, List, Optional, Tuple

import numpy
Expand Down Expand Up @@ -64,6 +65,8 @@
(20068, 1536, 512),
]

# M=13, K=2^6..2^25, N=2
LARGE_K_SHAPES = list(itertools.product([13], [2**i for i in range(6, 26)], [2]))

class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["tflops", "best_config"]
Expand All @@ -74,6 +77,8 @@ def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]]
addmm_args = parse_args(self.extra_args)
if addmm_args.m and addmm_args.n and addmm_args.k:
self.shapes = [(addmm_args.m, addmm_args.k, addmm_args.n)]
elif addmm_args.large_k_shapes:
self.shapes = LARGE_K_SHAPES
else:
self.shapes = BUILDIN_SHAPES
self.col_major = addmm_args.col_major
Expand Down

0 comments on commit f4ed185

Please sign in to comment.