Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
Max191 committed Feb 22, 2024
1 parent 1f1f61f commit 373c068
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
14 changes: 11 additions & 3 deletions core/shark_turbine/transforms/quantization/mm_group_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,13 @@ def match(self, op: Operation):


class MMGroupQuantRewriterPass(Pass):
def __init__(self, root_op: Operation, *, group_size: int = 128, param_names: Optional[set] = None):
def __init__(
self,
root_op: Operation,
*,
group_size: int = 128,
param_names: Optional[set] = None,
):
super().__init__(root_op)
self.group_size = group_size
self.context = root_op.context
Expand All @@ -168,15 +174,17 @@ def __init__(self, root_op: Operation, *, group_size: int = 128, param_names: Op
def run(self):
globals = self.globals
mms = match_children(self.funcs, TransposedMMMatcher(globals, self.builder))
view_mms = match_children(self.funcs, ViewTransposedMMMatcher(globals, self.builder))
view_mms = match_children(
self.funcs, ViewTransposedMMMatcher(globals, self.builder)
)

for mr in mms:
if mr.k is None or mr.n is None:
continue
if (mr.k % self.group_size) != 0:
continue
self.rewrite(mr)

for mr in view_mms:
if mr.k is None or mr.n is None:
continue
Expand Down
5 changes: 3 additions & 2 deletions core/shark_turbine/transforms/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,9 @@ def match(self, op: Operation) -> Optional[Transposed2DViewResult]:
if len(list_construct.operands) != 2:
return None
tensor_dims = self.builder.get_tensor_dims(op.operands[0].type)
if not ConstantIntMatcher(tensor_dims[0])(list_construct.operands[1]) or \
not ConstantIntMatcher(tensor_dims[1])(list_construct.operands[0]):
if not ConstantIntMatcher(tensor_dims[0])(
list_construct.operands[1]
) or not ConstantIntMatcher(tensor_dims[1])(list_construct.operands[0]):
return None
return Transposed2DViewResult(op)

Expand Down

0 comments on commit 373c068

Please sign in to comment.