Adapt transformers 4.45.1 (#2019)
Signed-off-by: Kaihui-intel <[email protected]>
Co-authored-by: changwangss <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Sep 30, 2024
1 parent d4662ad commit 44795a1
Showing 2 changed files with 33 additions and 13 deletions.
3 changes: 3 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/awq.py
@@ -516,6 +516,9 @@ def block_inference(self, model):
         """
         total_out = []
         for args, kwargs in zip(self.total_block_args, self.total_block_kwargs):
+            # reset layer_past, which becomes a DynamicCache on transformers newer than 4.45.1
+            if "layer_past" in kwargs.keys() and kwargs["layer_past"] is not None:
+                kwargs["layer_past"] = None
             out = model(*args, **kwargs)
             if isinstance(out, tuple):  # pragma: no cover
                 out = out[0]
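For context: per the comment in the hunk, transformers releases newer than 4.45.1 hand the block a DynamicCache through the layer_past kwarg, so replaying calibration inputs captured earlier can fail unless the stale cache is dropped. Below is a minimal standalone sketch of the same replay pattern; replay_blocks, block, and the captured (args, kwargs) lists are hypothetical stand-ins for the AWQ calibration state, and gating the reset on a 4.45.0 version check is an illustration, not part of the commit (which resets unconditionally).

    import transformers
    from packaging.version import Version

    # Assumption: 4.45.0 as the threshold; the commit itself does not version-gate.
    _NEEDS_CACHE_RESET = Version(transformers.__version__) >= Version("4.45.0")

    def replay_blocks(block, total_block_args, total_block_kwargs):
        """Replay captured (args, kwargs) pairs through one transformer block."""
        total_out = []
        for args, kwargs in zip(total_block_args, total_block_kwargs):
            if _NEEDS_CACHE_RESET and kwargs.get("layer_past") is not None:
                kwargs["layer_past"] = None  # drop the DynamicCache before replay
            out = block(*args, **kwargs)
            if isinstance(out, tuple):
                out = out[0]  # keep only the hidden states
            total_out.append(out)
        return total_out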
43 changes: 30 additions & 13 deletions neural_compressor/torch/algorithms/weight_only/save_load.py
@@ -834,19 +834,36 @@ def _load_remaining_pretrained_weight(self, model):
             resolved_archive_file = [resolved_archive_file]
         for shard_file in resolved_archive_file:
             state_dict = load_state_dict(shard_file)
-            _load_state_dict_into_meta_model(
-                model=model,
-                state_dict=state_dict,
-                loaded_state_dict_keys=self.loaded_state_dict_keys,
-                start_prefix="",
-                expected_keys=list(state_dict.keys()),
-                device_map={"": self.device},
-                offload_folder=offload_folder,
-                state_dict_folder=tempfile.mkdtemp() if offload_state_dict else None,
-                state_dict_index={} if offload_state_dict else None,
-                dtype=torch_dtype,
-                keep_in_fp32_modules=[],
-            )
+            import transformers
+            from packaging.version import Version
+
+            if Version(transformers.__version__) >= Version("4.45.0"):  # pragma: no cover
+                _load_state_dict_into_meta_model(
+                    model=model,
+                    state_dict=state_dict,
+                    start_prefix="",
+                    expected_keys=list(state_dict.keys()),
+                    device_map={"": self.device},
+                    offload_folder=offload_folder,
+                    state_dict_folder=tempfile.mkdtemp() if offload_state_dict else None,
+                    state_dict_index={} if offload_state_dict else None,
+                    dtype=torch_dtype,
+                    keep_in_fp32_modules=[],
+                )
+            else:
+                _load_state_dict_into_meta_model(
+                    model=model,
+                    state_dict=state_dict,
+                    loaded_state_dict_keys=self.loaded_state_dict_keys,
+                    start_prefix="",
+                    expected_keys=list(state_dict.keys()),
+                    device_map={"": self.device},
+                    offload_folder=offload_folder,
+                    state_dict_folder=tempfile.mkdtemp() if offload_state_dict else None,
+                    state_dict_index={} if offload_state_dict else None,
+                    dtype=torch_dtype,
+                    keep_in_fp32_modules=[],
+                )

         # make sure token embedding weights are still tied if needed
         model.tie_weights()
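The two branches above differ only in the loaded_state_dict_keys argument, which transformers 4.45 removed from _load_state_dict_into_meta_model. A more compact version dispatch is sketched below; load_shard_into_meta_model is a hypothetical helper, not part of the commit, and it assumes the private helper is importable from transformers.modeling_utils as it is in the releases the diff targets.

    import tempfile

    import transformers
    from packaging.version import Version
    from transformers.modeling_utils import _load_state_dict_into_meta_model

    def load_shard_into_meta_model(model, state_dict, loaded_state_dict_keys,
                                   device, offload_folder, offload_state_dict, torch_dtype):
        """Build the shared kwargs once; add the legacy argument only on pre-4.45."""
        kwargs = dict(
            model=model,
            state_dict=state_dict,
            start_prefix="",
            expected_keys=list(state_dict.keys()),
            device_map={"": device},
            offload_folder=offload_folder,
            state_dict_folder=tempfile.mkdtemp() if offload_state_dict else None,
            state_dict_index={} if offload_state_dict else None,
            dtype=torch_dtype,
            keep_in_fp32_modules=[],
        )
        if Version(transformers.__version__) < Version("4.45.0"):
            # releases before 4.45 still expect the caller to pass the loaded keys
            kwargs["loaded_state_dict_keys"] = loaded_state_dict_keys
        _load_state_dict_into_meta_model(**kwargs)

Funneling both signatures through one kwargs dict keeps the ten shared arguments from drifting apart when only one branch is edited later.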