
Commit 72398b6

fix xpu device set weight and bias (#2010)
Signed-off-by: changwangss <[email protected]>
Co-authored-by: Sun, Xuehao <[email protected]>
changwangss and XuehaoSun authored Sep 27, 2024
1 parent 9d27743 commit 72398b6
Showing 2 changed files with 14 additions and 18 deletions.
4 changes: 1 addition & 3 deletions neural_compressor/transformers/models/modeling_auto.py
@@ -180,9 +180,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             assert hasattr(torch, "xpu") and torch.xpu.is_available(), "There is no xpu device in this system!"
             quantization_config.update(**{"device": "xpu"})
             quantization_config.post_init_xpu()
-        if (
-            not torch.cuda.is_available() or device_map == "cpu" or device_map == torch.device("cpu")
-        ) and model.config.model_type == "chatglm":
+        if (device_map == "cpu" or device_map == torch.device("cpu")) and model.config.model_type == "chatglm":
             model = model.float()
         model = convert_to_quantized_model(model, quantization_config, device=device_map)
         if isinstance(quantization_config, AwqConfig):
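For illustration, a minimal standalone sketch of the guard this hunk changes (the helper name is hypothetical; only the boolean logic comes from the diff above). Before the fix, a chatglm model was cast to float32 whenever CUDA was unavailable, which also caught XPU runs; after it, only an explicit CPU device_map triggers the cast:

```python
import torch

def casts_to_float(device_map, model_type, cuda_available):
    # Hypothetical helper mirroring the chatglm guard in from_pretrained above.
    is_cpu = device_map == "cpu" or device_map == torch.device("cpu")
    before = (not cuda_available or is_cpu) and model_type == "chatglm"  # old guard
    after = is_cpu and model_type == "chatglm"  # new guard
    return before, after

# On an XPU-only machine (no CUDA) with device_map="xpu":
print(casts_to_float("xpu", "chatglm", cuda_available=False))  # (True, False)
```

With the old guard, a quantized chatglm model was forced to float32 even when targeting xpu; the new guard leaves non-CPU runs untouched.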
28 changes: 13 additions & 15 deletions neural_compressor/transformers/quantization/utils.py
@@ -223,30 +223,28 @@ def _replace_linear(
                         module.qzeros if hasattr(module, "qzeros") else None,
                         g_idx,
                     )
-                    if not hasattr(module, "qweight"):
-                        n_pack = 32 // quantization_config.bits
-
-                        weight = torch.zeros(
-                            (math.ceil(in_features / n_pack), out_features),
-                            dtype=torch.int32,
-                            device=torch.device(device),
-                        )
-                    model._modules[name].set_weights_bias(
-                        module.qweight.data if hasattr(module, "qweight") else weight,
-                        None if module.bias is None else module.bias.data,
-                    )
                 else:
                     raise Exception("{} device Unsupported weight only quantization!".format(device))
 
                 is_replaced = True
-                is_removed = True
                 # Store the module class in case we need to transpose the weight later
                 model._modules[name].source_cls = type(module)
                 # Force requires grad to False to avoid unexpected errors
                 model._modules[name].requires_grad_(False)
 
+                if device == "xpu" or device == torch.device("xpu"):
+                    if not hasattr(module, "qweight"):
+                        n_pack = 32 // quantization_config.bits
+
+                        weight = torch.zeros(
+                            (math.ceil(in_features / n_pack), out_features),
+                            dtype=torch.int32,
+                            device=torch.device(device),
+                        )
+                    model._modules[name].set_weights_bias(
+                        module.qweight.data if hasattr(module, "qweight") else weight,
+                        None if module.bias is None else module.bias.data,
+                    )
+                    is_removed = True
+
         if not is_removed and len(list(module.children())) > 0:  # pylint: disable=E1101
             _, is_replaced = _replace_linear(
                 module,
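As background on the placeholder shape used in the relocated block, weight-only quantization packs 32 // bits low-bit values into each INT32 element along the input dimension. A standalone sketch with made-up sizes (the bit width and layer dimensions are assumptions, not taken from the diff):

```python
import math

import torch

bits = 4  # assume 4-bit weight-only quantization
in_features, out_features = 4096, 11008  # made-up linear-layer sizes

n_pack = 32 // bits  # 8 four-bit values packed into each 32-bit integer
qweight_placeholder = torch.zeros(
    (math.ceil(in_features / n_pack), out_features),
    dtype=torch.int32,
)
print(qweight_placeholder.shape)  # torch.Size([512, 11008])
```

This mirrors the torch.zeros call above: when a module has no qweight yet, a correctly shaped packed-INT32 buffer is handed to set_weights_bias in its place.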
