Merge branch 'master' into uly-hf
samadejacobs authored Aug 20, 2024
2 parents cb7c20e + 96393f5 commit 513d479
Showing 2 changed files with 11 additions and 4 deletions.
4 changes: 2 additions & 2 deletions deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -15,13 +15,13 @@

 class HuggingFaceCheckpointEngine(CheckpointEngineBase):
 
-    def __init__(self, model_name_or_path: str, auth_token: str = None) -> None:
+    def __init__(self, model_name_or_path: str, auth_token: str = None, **hf_kwargs) -> None:
         super().__init__()
         from transformers import AutoConfig, GenerationConfig
 
         self.model_name_or_path = model_name_or_path
         self.auth_token = auth_token
-        self.model_config = AutoConfig.from_pretrained(self.model_name_or_path)
+        self.model_config = AutoConfig.from_pretrained(self.model_name_or_path, **hf_kwargs)
         # Define this property here so we can use it in the model implementation
         if not hasattr(self.model_config, "max_seq_length"):
             if hasattr(self.model_config, "max_position_embeddings"):
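For reference, a minimal usage sketch of the widened constructor: any extra keyword arguments are forwarded unchanged to AutoConfig.from_pretrained. The import path, model id, and the specific kwargs (trust_remote_code, revision) are illustrative assumptions, not part of this diff.

    from deepspeed.inference.v2.checkpoint import HuggingFaceCheckpointEngine  # import path assumed from the file location

    # Extra keyword arguments pass straight through to AutoConfig.from_pretrained,
    # e.g. for models that ship custom configuration code or need a pinned revision.
    engine = HuggingFaceCheckpointEngine(
        "some-org/some-model",      # hypothetical model id
        auth_token=None,
        trust_remote_code=True,     # forwarded via **hf_kwargs
        revision="main",            # forwarded via **hf_kwargs
    )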
11 changes: 9 additions & 2 deletions deepspeed/runtime/zero/linear.py
@@ -16,6 +16,7 @@
 #when implemented outside of torch.autograd.Function
 
 import math
+import functools
 
 import torch
 from torch import Tensor
@@ -33,8 +34,14 @@ def print_rank_0(message, debug=False, force=False):


 try:
-    autocast_custom_fwd = get_accelerator().amp().custom_fwd
-    autocast_custom_bwd = get_accelerator().amp().custom_bwd
+    # Fix `torch.[device].amp.custom_fwd/bwd` FutureWarning in torch 2.4
+    if hasattr(torch, 'amp') and hasattr(torch.amp, 'custom_fwd') and hasattr(torch.amp, 'custom_bwd'):
+        autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name())
+        autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name())
+    else:
+        # original implementation
+        autocast_custom_fwd = get_accelerator().amp().custom_fwd
+        autocast_custom_bwd = get_accelerator().amp().custom_bwd
 except (ImportError, AttributeError) as exp:
     autocast_custom_fwd = noop_decorator
     autocast_custom_bwd = noop_decorator
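For reference, a self-contained sketch of the pattern this hunk applies: prefer the device-agnostic torch.amp.custom_fwd/custom_bwd decorators (which require an explicit device_type and replace the deprecated torch.cuda.amp variants in torch 2.4) by binding the device with functools.partial, and fall back to the older decorators or a no-op otherwise. The hard-coded "cuda" device and the toy LinearFunction are assumptions for illustration; DeepSpeed itself resolves the device via get_accelerator().device_name().

    import functools
    import torch


    def noop_decorator(func):
        # Last-resort fallback when no autocast decorators are available
        return func


    try:
        if hasattr(torch, "amp") and hasattr(torch.amp, "custom_fwd") and hasattr(torch.amp, "custom_bwd"):
            # torch >= 2.4: device-agnostic decorators that take an explicit device_type
            custom_fwd = functools.partial(torch.amp.custom_fwd, device_type="cuda")
            custom_bwd = functools.partial(torch.amp.custom_bwd, device_type="cuda")
        else:
            # older torch: per-device namespace (emits a FutureWarning on torch 2.4+)
            custom_fwd = torch.cuda.amp.custom_fwd
            custom_bwd = torch.cuda.amp.custom_bwd
    except (ImportError, AttributeError):
        custom_fwd = noop_decorator
        custom_bwd = noop_decorator


    class LinearFunction(torch.autograd.Function):
        # Toy autograd.Function showing where the decorators are applied;
        # autocast casting behavior is handled by custom_fwd/custom_bwd.

        @staticmethod
        @custom_fwd
        def forward(ctx, input, weight):
            ctx.save_for_backward(input, weight)
            return input.matmul(weight.t())

        @staticmethod
        @custom_bwd
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            return grad_output.matmul(weight), grad_output.t().matmul(input)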
