Skip to content

Commit

Permalink
Allow GPTQModel to auto select Marlin or faster kernels for inference only ops (#2138)
Browse files Browse the repository at this point in the history

* select quant_linear with pack

* up GPTQMODEL_MINIMUM_VERSION

* Update quantizer.py

* update gptqmodel version

---------

Co-authored-by: Qubitium-ModelCloud <[email protected]>
  • Loading branch information
LRL-ModelCloud and Qubitium authored Jan 8, 2025
1 parent 72498dd commit 53240c3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions optimum/gptq/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def __init__(
)
self.exllama_version = self.exllama_config["version"]

def select_quant_linear(self, device_map: Union[str, dict]):
def select_quant_linear(self, device_map: Union[str, dict], pack: bool = False):
if is_gptqmodel_available():
self.quant_linear = hf_select_quant_linear(
bits=self.bits,
Expand All @@ -231,6 +231,7 @@ def select_quant_linear(self, device_map: Union[str, dict]):
meta=self.meta,
device_map=device_map,
backend=self.backend,
pack=pack,
)
else:
self.quant_linear = hf_select_quant_linear(
Expand Down Expand Up @@ -301,7 +302,7 @@ def convert_model(self, model: nn.Module, **kwargs):
)
del layers_to_be_replaced[name]

self.select_quant_linear(device_map=kwargs.get("device_map", None))
self.select_quant_linear(device_map=kwargs.get("device_map", None), pack=False)

self._replace_by_quant_layers(model, layers_to_be_replaced)

Expand Down Expand Up @@ -761,7 +762,7 @@ def pack_model(
layers = get_layers(model)
layers = {n: layers[n] for n in quantizers}

self.select_quant_linear(device_map=model.hf_device_map)
self.select_quant_linear(device_map=model.hf_device_map, pack=True)

self._replace_by_quant_layers(model, quantizers)
qlayers = get_layers(model, [self.quant_linear])
Expand Down
2 changes: 1 addition & 1 deletion optimum/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0")
DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0")
AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0
GPTQMODEL_MINIMUM_VERSION = version.parse("1.4.2")
GPTQMODEL_MINIMUM_VERSION = version.parse("1.6.0")


# This is the minimal required version to support some ONNX Runtime features
Expand Down

0 comments on commit 53240c3

Please sign in to comment.