Skip to content

Commit

Permalink
Mixtral support
Browse files Browse the repository at this point in the history
  • Loading branch information
Vahe1994 committed Jan 17, 2024
1 parent 81f69a8 commit d6561e4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
9 changes: 6 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,16 @@ def quantize_aq(model: PreTrainedModel, dataloader: Iterable, args: Namespace):
sequential = get_sequential_groups(model)
else:
sequential = [list(find_sublayers(layer).keys())]

for names in sequential:
if len(args.devices) == 1:
assert len(inps) == len(outs) == 1 # number of per-device inputs/outputs
aq_handlers = init_aq_engines(layer, names, inps[0], outs[0], **forward_args)
aq_handlers = init_aq_engines(
layer, [name for name in names if "gate" not in name], inps[0], outs[0], **forward_args
)
else:
aq_handlers = init_aq_engines_parallel(args.devices, layer, names, inps, outs, **forward_args)
aq_handlers = init_aq_engines_parallel(
args.devices, layer, [name for name in names if "gate" not in name], inps, outs, **forward_args
)

for sublayer_name in aq_handlers.keys():
print(f"Quantizing module {sublayer_name} of layer {layer_index}")
Expand Down
3 changes: 2 additions & 1 deletion src/modelutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

MODEL_ERROR_MSG = "Unsupported model type {} - only 'llama', 'Yi', 'opt' and 'falcon' are supported"
FALCON_TYPES = ("falcon", "refinedweb", "refinedwebmodel")
LLAMA_LIKE = ("llama", "Yi", "mistral")
LLAMA_LIKE = ("llama", "Yi", "mistral", "mixtral")


@contextmanager
Expand Down Expand Up @@ -119,6 +119,7 @@ def find_sublayers(module, layers=(nn.Conv2d, nn.Linear)):

def get_sequential_groups(model):
if model.config.model_type in LLAMA_LIKE:
assert "mixtral" in model.config.model_type
return [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
["self_attn.o_proj"],
Expand Down

0 comments on commit d6561e4

Please sign in to comment.