Add additional matching logic to MMGroupQuantRewriterPass
Max191 committed Feb 22, 2024
1 parent fabd52c commit 1f1f61f
Showing 3 changed files with 103 additions and 4 deletions.
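
The rewriter pass now accepts an optional param_names set that restricts rewriting to the named parameters. A minimal usage sketch, assuming a compiled module inst as in the stateless_llama.py change below (the parameter name in param_set is hypothetical):

from shark_turbine.transforms.quantization import mm_group_quant

# Restrict the group-quant rewrite to an explicit set of parameter names
# (plain safetensor-style names, without the "_params." prefix).
param_set = {"model.layers.0.self_attn.q_proj.weight"}  # hypothetical name
mm_group_quant.MMGroupQuantRewriterPass(
    CompiledModule.get_mlir_module(inst).operation,
    group_size=128,  # default group size
    param_names=param_set,
).run()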
48 changes: 45 additions & 3 deletions core/shark_turbine/transforms/quantization/mm_group_quant.py
@@ -48,6 +48,8 @@ def __init__(self, globals: GlobalsDict, builder: Builder):

def match(self, op: Operation):
weight_transpose = Transpose2DMatcher()(op.operands[1])
if not weight_transpose:
weight_transpose = PermuteMatcher([1, 0])(op.operands[1])
if not weight_transpose:
return None
weight_load = GlobalLoadMatcher(self.globals)(weight_transpose.input)
@@ -67,6 +69,38 @@ def match(self, op: Operation):
)


class ViewTransposedMMMatcher(NamedOpMatcher):
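# Like TransposedMMMatcher, but the weight operand of torch.aten.mm may additionally
# pass through a 2-D view with swapped dims (Transposed2DViewMatcher) between the
# global load and the transpose/permute.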
def __init__(self, globals: GlobalsDict, builder: Builder):
super().__init__("torch.aten.mm")
self.globals = globals
self.builder = builder

def match(self, op: Operation):
weight_transpose = Transpose2DMatcher()(op.operands[1])
if not weight_transpose:
weight_transpose = PermuteMatcher([1, 0])(op.operands[1])
if not weight_transpose:
return None
weight_view = Transposed2DViewMatcher(self.builder)(weight_transpose.input)
if not weight_view:
return None
weight_load = GlobalLoadMatcher(self.globals)(weight_view.input)
if not weight_load or not weight_load.resolved_global:
return None

m, n = self.builder.get_tensor_dims(op.operands[0].type)
_, k = self.builder.get_tensor_dims(op.operands[1].type)
return TransposedMMResult(
op,
weight_global=weight_load.resolved_global,
param_name=weight_load.global_ref,
m=m,
n=n,
k=k,
element_type=self.builder.get_tensor_element_type(op.operands[0].type),
)


# TODO (ian): Make more generalizable using RenameParametersPass. Currently hardcoded for brevitas quantization
GROUP_MATMUL_TEMPLATE = r"""
module {{
@@ -125,29 +159,37 @@ def match(self, op: Operation):


class MMGroupQuantRewriterPass(Pass):
def __init__(self, root_op: Operation, *, group_size: int = 128):
def __init__(self, root_op: Operation, *, group_size: int = 128, param_names: Optional[set] = None):
super().__init__(root_op)
self.group_size = group_size
self.context = root_op.context
self.param_names = param_names

def run(self):
globals = self.globals
mms = match_children(self.funcs, TransposedMMMatcher(globals, self.builder))
view_mms = match_children(self.funcs, ViewTransposedMMMatcher(globals, self.builder))

for mr in mms:
if mr.k is None or mr.n is None:
continue
if (mr.k % self.group_size) != 0:
continue
self.rewrite(mr)

for mr in view_mms:
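# Matches found via the view pattern require both k and n to be multiples of the group size.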
if mr.k is None or mr.n is None:
continue
if (mr.k % self.group_size) != 0 or (mr.n % self.group_size) != 0:
continue
self.rewrite(mr)

self.inline()
self.cleanup()

def rewrite(self, mr: TransposedMMResult):
none_to_q = lambda x: "?" if x is None else x
# TODO (ian): make generalizable and not specific for brevitas
if "lm_head.weight" not in mr.param_name:
if self.param_names is None or mr.param_name[8:] in self.param_names:
inline_module_asm = GROUP_MATMUL_TEMPLATE.format(
# TODO (ian): Fix skipping the "_params." portion of the name to match safetensor format with RenameParametersPass
param_name=mr.param_name[8:],
51 changes: 51 additions & 0 deletions core/shark_turbine/transforms/rewriter.py
@@ -38,6 +38,8 @@
"OpMatchResult",
"Pass",
"Transpose2DMatcher",
"PermuteMatcher",
"Transposed2DViewMatcher",
"match_children",
"pass_main",
]
@@ -194,6 +196,55 @@ def match(self, op: Operation) -> Optional[Transpose2DResult]:
return result


class PermuteResult(OpMatchResult):
@property
def input(self) -> Value:
return self.op.operands[0]


class PermuteMatcher(NamedOpMatcher):
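# Matches torch.aten.permute whose permutation list is a torch.prim.ListConstruct of
# torch.constant.int values equal to the expected permutation (e.g. [1, 0]).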
def __init__(self, permutation: list[int]):
super().__init__("torch.aten.permute")
self.permutation = permutation

def match(self, op: Operation) -> Optional[PermuteResult]:
list_construct = NamedOpMatcher("torch.prim.ListConstruct")(op.operands[1])
if not list_construct:
return None
list_construct = list_construct.op
if len(self.permutation) != len(list_construct.operands):
return None
for i, list_item in enumerate(list_construct.operands):
if not ConstantIntMatcher(self.permutation[i])(list_item):
return None
return PermuteResult(op)


class Transposed2DViewResult(OpMatchResult):
@property
def input(self) -> Value:
return self.op.operands[0]


class Transposed2DViewMatcher(NamedOpMatcher):
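# Matches a torch.aten.view whose target shape is a 2-element ListConstruct equal to
# the input tensor's dims in reverse order, i.e. a view to the transposed 2-D shape.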
def __init__(self, builder: Builder):
super().__init__("torch.aten.view")
self.builder = builder

def match(self, op: Operation) -> Optional[Transposed2DViewResult]:
list_construct = NamedOpMatcher("torch.prim.ListConstruct")(op.operands[1])
if not list_construct:
return None
list_construct = list_construct.op
if len(list_construct.operands) != 2:
return None
tensor_dims = self.builder.get_tensor_dims(op.operands[0].type)
if not ConstantIntMatcher(tensor_dims[0])(list_construct.operands[1]) or \
not ConstantIntMatcher(tensor_dims[1])(list_construct.operands[0]):
return None
return Transposed2DViewResult(op)


class ConstantIntMatcher(NamedOpMatcher):
def __init__(self, value: int):
super().__init__("torch.constant.int")
8 changes: 7 additions & 1 deletion models/turbine_models/custom_models/stateless_llama.py
@@ -140,11 +140,13 @@ def export_transformer_model(
)

mapper = {}
param_set = set()
if external_weights is not None:
if external_weights == "safetensors":
mod_params = dict(mod.named_parameters())
for name in mod_params:
mapper["params." + name] = name
param_set.add(name)
if external_weight_file:
safetensors.torch.save_file(mod_params, external_weight_file)

@@ -313,8 +315,12 @@ def evict_kvcache_space(self):
if quantization == "int4" and not compile_to == "linalg":
from shark_turbine.transforms.quantization import mm_group_quant

print(param_set)
if "lm_head.weight" in param_set:
param_set.remove("lm_head.weight")
mm_group_quant.MMGroupQuantRewriterPass(
CompiledModule.get_mlir_module(inst).operation
CompiledModule.get_mlir_module(inst).operation,
param_names=param_set,
).run()
module_str = str(CompiledModule.get_mlir_module(inst))
safe_name = hf_model_name.split("/")[-1].strip()