From 44795a1ae93f3676a595063cf0e6f680c41989b2 Mon Sep 17 00:00:00 2001
From: Kaihui-intel
Date: Mon, 30 Sep 2024 16:55:22 +0800
Subject: [PATCH] Adapt transformers 4.45.1 (#2019)

Signed-off-by: Kaihui-intel
Co-authored-by: changwangss
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../torch/algorithms/weight_only/awq.py       |  3 ++
 .../torch/algorithms/weight_only/save_load.py | 43 +++++++++++++------
 2 files changed, 33 insertions(+), 13 deletions(-)

diff --git a/neural_compressor/torch/algorithms/weight_only/awq.py b/neural_compressor/torch/algorithms/weight_only/awq.py
index 00d7fb5172c..677f3cb9899 100644
--- a/neural_compressor/torch/algorithms/weight_only/awq.py
+++ b/neural_compressor/torch/algorithms/weight_only/awq.py
@@ -516,6 +516,9 @@ def block_inference(self, model):
         """
         total_out = []
         for args, kwargs in zip(self.total_block_args, self.total_block_kwargs):
+            # to avoid layer_past being a DynamicCache when transformers is higher than 4.45.1
+            if "layer_past" in kwargs.keys() and kwargs["layer_past"] is not None:
+                kwargs["layer_past"] = None
             out = model(*args, **kwargs)
             if isinstance(out, tuple):  # pragma: no cover
                 out = out[0]
diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py
index 8d1259cad00..7d22c7efbc9 100644
--- a/neural_compressor/torch/algorithms/weight_only/save_load.py
+++ b/neural_compressor/torch/algorithms/weight_only/save_load.py
@@ -834,19 +834,36 @@ def _load_remaining_pretrained_weight(self, model):
             resolved_archive_file = [resolved_archive_file]
         for shard_file in resolved_archive_file:
             state_dict = load_state_dict(shard_file)
-            _load_state_dict_into_meta_model(
-                model=model,
-                state_dict=state_dict,
-                loaded_state_dict_keys=self.loaded_state_dict_keys,
-                start_prefix="",
-                expected_keys=list(state_dict.keys()),
-                device_map={"": self.device},
-                offload_folder=offload_folder,
-                state_dict_folder=tempfile.mkdtemp() if offload_state_dict else None,
-                state_dict_index={} if offload_state_dict else None,
-                dtype=torch_dtype,
-                keep_in_fp32_modules=[],
-            )
+            import transformers
+            from packaging.version import Version
+
+            if Version(transformers.__version__) >= Version("4.45.0"):  # pragma: no cover
+                _load_state_dict_into_meta_model(
+                    model=model,
+                    state_dict=state_dict,
+                    start_prefix="",
+                    expected_keys=list(state_dict.keys()),
+                    device_map={"": self.device},
+                    offload_folder=offload_folder,
+                    state_dict_folder=tempfile.mkdtemp() if offload_state_dict else None,
+                    state_dict_index={} if offload_state_dict else None,
+                    dtype=torch_dtype,
+                    keep_in_fp32_modules=[],
+                )
+            else:
+                _load_state_dict_into_meta_model(
+                    model=model,
+                    state_dict=state_dict,
+                    loaded_state_dict_keys=self.loaded_state_dict_keys,
+                    start_prefix="",
+                    expected_keys=list(state_dict.keys()),
+                    device_map={"": self.device},
+                    offload_folder=offload_folder,
+                    state_dict_folder=tempfile.mkdtemp() if offload_state_dict else None,
+                    state_dict_index={} if offload_state_dict else None,
+                    dtype=torch_dtype,
+                    keep_in_fp32_modules=[],
+                )
 
         # make sure token embedding weights are still tied if needed
         model.tie_weights()
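
For reference, the compatibility gate added to save_load.py hinges on a single check: on transformers >= 4.45.0 the patch stops passing loaded_state_dict_keys to _load_state_dict_into_meta_model. Below is a minimal sketch of that check in isolation, assuming only that transformers and packaging are installed; the helper name is illustrative and not part of the patch.

    import transformers
    from packaging.version import Version

    def uses_legacy_loader_signature() -> bool:
        # Illustrative helper: True when the installed transformers is older than
        # 4.45.0, i.e. _load_state_dict_into_meta_model is still called with the
        # loaded_state_dict_keys keyword (the else-branch in the patch above).
        return Version(transformers.__version__) < Version("4.45.0")

    if __name__ == "__main__":
        print("use legacy loader signature:", uses_legacy_loader_signature())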