diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
index 46a84c61f884..77dfa7a23b1e 100644
--- a/deepspeed/inference/v2/checkpoint/huggingface_engine.py
+++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -15,13 +15,13 @@ class HuggingFaceCheckpointEngine(CheckpointEngineBase):
 
-    def __init__(self, model_name_or_path: str, auth_token: str = None) -> None:
+    def __init__(self, model_name_or_path: str, auth_token: str = None, **hf_kwargs) -> None:
         super().__init__()
         from transformers import AutoConfig, GenerationConfig
 
         self.model_name_or_path = model_name_or_path
         self.auth_token = auth_token
-        self.model_config = AutoConfig.from_pretrained(self.model_name_or_path)
+        self.model_config = AutoConfig.from_pretrained(self.model_name_or_path, **hf_kwargs)
         # Define this property here so we can use it in the model implementation
         if not hasattr(self.model_config, "max_seq_length"):
             if hasattr(self.model_config, "max_position_embeddings"):
diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py
index e9dd78864cde..8c8db60768eb 100644
--- a/deepspeed/runtime/zero/linear.py
+++ b/deepspeed/runtime/zero/linear.py
@@ -16,6 +16,7 @@
 #when implemented outside of torch.autograd.Function
 
 import math
+import functools
 
 import torch
 from torch import Tensor
@@ -33,8 +34,14 @@ def print_rank_0(message, debug=False, force=False):
 
 
 try:
-    autocast_custom_fwd = get_accelerator().amp().custom_fwd
-    autocast_custom_bwd = get_accelerator().amp().custom_bwd
+    # Fix `torch.[device].amp.custom_fwd/bwd` FutureWarning in torch 2.4
+    if hasattr(torch, 'amp') and hasattr(torch.amp, 'custom_fwd') and hasattr(torch.amp, 'custom_bwd'):
+        autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name())
+        autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name())
+    else:
+        # original implementation
+        autocast_custom_fwd = get_accelerator().amp().custom_fwd
+        autocast_custom_bwd = get_accelerator().amp().custom_bwd
 except (ImportError, AttributeError) as exp:
     autocast_custom_fwd = noop_decorator
     autocast_custom_bwd = noop_decorator
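
Not part of the diff — a minimal sketch of how the partial-applied decorators from `linear.py` are meant to behave, assuming torch >= 2.4 (where `torch.amp.custom_fwd`/`custom_bwd` accept a `device_type` keyword) and a CUDA accelerator whose `device_name()` returns `"cuda"`. `SimpleLinearFn` is a hypothetical, simplified stand-in for the bias-aware `LinearFunctionForZeroStage3` in that file:

```python
import functools

import torch

# Assumption: torch >= 2.4, so torch.amp.custom_fwd/custom_bwd take a
# device_type keyword; "cuda" stands in for get_accelerator().device_name().
autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type="cuda")
autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type="cuda")


class SimpleLinearFn(torch.autograd.Function):
    """Simplified stand-in for the autograd linear function (no bias handling)."""

    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return input.matmul(weight.t())

    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        return grad_output.matmul(weight), grad_output.t().matmul(input)


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
SimpleLinearFn.apply(x, w).sum().backward()  # decorated path, no FutureWarning
```

On the checkpoint-engine side, the new `**hf_kwargs` simply forwards extra HuggingFace options, e.g. `HuggingFaceCheckpointEngine(model_name_or_path, trust_remote_code=True)`, through to `AutoConfig.from_pretrained`.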