DRAFT: Apply HF fallbacks to all from_pretrained() object initializations. #3745

Closed · wants to merge 4 commits
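Every `from_pretrained()` call in the diff below is rerouted through `load_pretrained_hf_class_with_hub_fallback` from `ludwig/utils/hf_utils.py`, a generalization of the existing `load_pretrained_hf_model_with_hub_fallback` to arbitrary HF classes such as `AutoConfig` and model classes. The helper itself is not part of this diff; the sketch below is only an assumption about its shape (the `(object, used_fallback)` return tuple, the primary-source env var name, and the retry-on-`OSError` behavior are all assumptions).

```python
# Hypothetical sketch of the helper this PR calls everywhere; the real
# implementation lives in ludwig/utils/hf_utils.py and may differ.
import logging
import os
from typing import Any, Tuple

logger = logging.getLogger(__name__)


def load_pretrained_hf_class_with_hub_fallback(
    hf_class: Any, pretrained_model_name_or_path: str, **pretrained_kwargs
) -> Tuple[Any, bool]:
    """Load any HF `from_pretrained`-style class with a fallback to the HF Hub.

    Returns (loaded_object, used_fallback) so call sites can tell whether the
    primary source failed. The env var name below is an assumption.
    """
    primary_source = os.environ.get("LUDWIG_PRETRAINED_MODELS_DIR")  # assumed name
    if primary_source:
        try:
            return (
                hf_class.from_pretrained(
                    os.path.join(primary_source, pretrained_model_name_or_path), **pretrained_kwargs
                ),
                False,
            )
        except OSError:
            logger.warning(
                "Could not load %s from %s; falling back to the Hugging Face Hub.",
                pretrained_model_name_or_path,
                primary_source,
            )
            return hf_class.from_pretrained(pretrained_model_name_or_path, **pretrained_kwargs), True
    # No primary source configured: load from the Hub directly (not counted as a fallback).
    return hf_class.from_pretrained(pretrained_model_name_or_path, **pretrained_kwargs), False
```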
3 changes: 2 additions & 1 deletion ludwig/config_validation/checks.py
@@ -25,6 +25,7 @@
     VECTOR,
 )
 from ludwig.error import ConfigValidationError
+from ludwig.utils.hf_utils import load_pretrained_hf_class_with_hub_fallback
 from ludwig.utils.metric_utils import get_feature_to_metric_names_map_from_feature_collection
 from ludwig.utils.misc_utils import merge_dict

@@ -594,7 +595,7 @@ def check_llm_finetuning_adaption_prompt_parameters(config: "ModelConfig"):
 
 def _get_llm_model_config(model_name: str) -> AutoConfig:
     """Returns the LLM model config."""
-    return AutoConfig.from_pretrained(model_name)
+    return load_pretrained_hf_class_with_hub_fallback(AutoConfig, model_name)[0]
 
 
 # TODO(geoffrey, arnav): uncomment this when we have reconciled the config with the backend kwarg in api.py
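Note the `[0]` in the new `_get_llm_model_config` body: the helper returns a tuple rather than the bare object, so call sites either index it or unpack it. A small hedged example of the two equivalent patterns (assuming the `(object, used_fallback)` convention sketched above):

```python
from transformers import AutoConfig

from ludwig.utils.hf_utils import load_pretrained_hf_class_with_hub_fallback

# Both lines discard the second tuple element (the presumed fallback flag).
config, _ = load_pretrained_hf_class_with_hub_fallback(AutoConfig, "gpt2")
config = load_pretrained_hf_class_with_hub_fallback(AutoConfig, "gpt2")[0]
```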
3 changes: 2 additions & 1 deletion ludwig/encoders/image/base.py
@@ -33,6 +33,7 @@
     Stacked2DCNNConfig,
     ViTConfig,
 )
+from ludwig.utils.hf_utils import load_pretrained_hf_class_with_hub_fallback
 from ludwig.utils.torch_utils import FreezeModule
 
 logger = logging.getLogger(__name__)
@@ -382,7 +383,7 @@ def __init__(
         self._input_shape = (in_channels, img_height, img_width)
 
         if use_pretrained and not saved_weights_in_checkpoint:
-            transformer = ViTModel.from_pretrained(pretrained_model)
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(ViTModel, pretrained_model)
         else:
             config = ViTConfig(
                 image_size=img_height,
40 changes: 20 additions & 20 deletions ludwig/encoders/text_encoders.py
@@ -52,7 +52,7 @@
     XLNetConfig,
 )
 from ludwig.schema.llms.peft import BaseAdapterConfig
-from ludwig.utils.hf_utils import load_pretrained_hf_model_with_hub_fallback
+from ludwig.utils.hf_utils import load_pretrained_hf_class_with_hub_fallback
 from ludwig.utils.torch_utils import FreezeModule
 
 if TYPE_CHECKING:
@@ -179,7 +179,7 @@ def __init__(
         hf_config_params = {k: v for k, v in kwargs.items() if k in schema_cls.get_hf_config_param_names()}
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 model_cls, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -302,7 +302,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 AlbertModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -428,7 +428,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 MT5EncoderModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -524,7 +524,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 XLMRobertaModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -646,7 +646,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 BertModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -793,7 +793,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 XLMModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -900,7 +900,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 OpenAIGPTModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -1007,7 +1007,7 @@ def __init__(
 
         if use_pretrained:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 GPT2Model, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -1110,7 +1110,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 RobertaModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -1243,7 +1243,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 TransfoXLModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -1371,7 +1371,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 XLNetModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -1475,7 +1475,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 DistilBertModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -1585,7 +1585,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 CTRLModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
             self.vocab_size = transformer.config.vocab_size
@@ -1698,7 +1698,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 CamembertModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -1812,7 +1812,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 T5Model, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -1949,7 +1949,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 FlaubertModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -2066,7 +2066,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 ElectraModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -2159,7 +2159,7 @@ def __init__(
 
         if use_pretrained and not saved_weights_in_checkpoint:
             pretrained_kwargs = pretrained_kwargs or {}
-            transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+            transformer, _ = load_pretrained_hf_class_with_hub_fallback(
                 LongformerModel, pretrained_model_name_or_path, **pretrained_kwargs
             )
         else:
@@ -2243,7 +2243,7 @@ def __init__(
         from transformers import AutoModel
 
         pretrained_kwargs = pretrained_kwargs or {}
-        transformer, _ = load_pretrained_hf_model_with_hub_fallback(
+        transformer, _ = load_pretrained_hf_class_with_hub_fallback(
             AutoModel, pretrained_model_name_or_path, **pretrained_kwargs
         )
         self._maybe_resize_token_embeddings(transformer, vocab_size)
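Every encoder call site above discards the helper's second return value with `_`. If that value is indeed a fallback indicator (an assumption, since the helper is not part of this diff), a caller that wants visibility into fallback downloads could keep it instead:

```python
import logging

from transformers import BertModel

from ludwig.utils.hf_utils import load_pretrained_hf_class_with_hub_fallback

logger = logging.getLogger(__name__)

# Hypothetical variant of the call pattern used throughout text_encoders.py,
# keeping the flag instead of discarding it.
transformer, used_fallback = load_pretrained_hf_class_with_hub_fallback(BertModel, "bert-base-uncased")
if used_fallback:
    logger.info("bert-base-uncased was loaded via the fallback source.")
```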
21 changes: 13 additions & 8 deletions ludwig/models/llm.py
@@ -6,7 +6,7 @@
 
 import numpy as np
 import torch
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, PreTrainedModel
+from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig, PreTrainedModel
 
 from ludwig.constants import IGNORE_INDEX_TOKEN_ID, LOGITS, MODEL_LLM, PREDICTIONS, TEXT
 from ludwig.features.base_feature import ModuleWrapper, OutputFeature
@@ -19,17 +19,18 @@
 from ludwig.utils.augmentation_utils import AugmentationPipelines
 from ludwig.utils.data_utils import clear_data_cache
 from ludwig.utils.error_handling_utils import default_retry
+from ludwig.utils.hf_utils import load_pretrained_hf_class_with_hub_fallback
 from ludwig.utils.llm_utils import (
     add_left_padding,
     generate_merged_ids,
     get_context_len,
     pad_target_tensor_for_fine_tuning,
     realign_target_and_prediction_tensors_for_inference,
     remove_left_padding,
-    set_pad_token,
 )
 from ludwig.utils.logging_utils import log_once
 from ludwig.utils.output_feature_utils import set_output_feature_tensor
+from ludwig.utils.tokenizers import HFTokenizer
 from ludwig.utils.torch_utils import reg_loss
 
 logger = logging.getLogger(__name__)
@@ -101,7 +102,9 @@ def load_pretrained_from_config(
 
     logger.info("Loading large language model...")
     pretrained_model_name_or_path = weights_save_path or config_obj.base_model
-    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
+    model, _ = load_pretrained_hf_class_with_hub_fallback(
+        AutoModelForCausalLM, pretrained_model_name_or_path, **load_kwargs
+    )
     return model
 
 
@@ -123,7 +126,7 @@ def __init__(
         self._random_seed = random_seed
 
         self.model_name = self.config_obj.base_model
-        self.model_config = AutoConfig.from_pretrained(self.config_obj.base_model)
+        self.model_config, _ = load_pretrained_hf_class_with_hub_fallback(AutoConfig, self.config_obj.base_model)
 
         self.model = load_pretrained_from_config(self.config_obj, model_config=self.model_config)
         self.curr_device = next(self.model.parameters()).device
@@ -144,8 +147,8 @@ def __init__(
         self.global_max_sequence_length = self.context_len
 
         # Initialize tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(self.config_obj.base_model)
-        set_pad_token(self.tokenizer)
+        ludwig_tokenizer = HFTokenizer(self.config_obj.base_model)
+        self.tokenizer = ludwig_tokenizer.tokenizer
 
         self._set_generation_config(self.config_obj.generation.to_dict())
 
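The tokenizer change above removes the explicit `set_pad_token` call and relies on Ludwig's `HFTokenizer` wrapper (from `ludwig/utils/tokenizers.py`) exposing a ready-to-use `tokenizer` attribute. A rough sketch of the assumption this swap makes; the real wrapper's behavior may differ:

```python
# Hypothetical illustration of what the tokenizer swap assumes; the actual
# HFTokenizer class in ludwig/utils/tokenizers.py may handle this differently.
from transformers import AutoTokenizer


class HFTokenizerSketch:
    def __init__(self, pretrained_model_name_or_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        # Assumption: the wrapper sets a pad token when the base model lacks one,
        # replacing the explicit set_pad_token() call removed in this diff.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
```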
@@ -300,7 +303,8 @@ def to_device(self, device):
                 if self.config_obj.adapter:
                     from peft import PeftModel
 
-                    self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model, _ = load_pretrained_hf_class_with_hub_fallback(
+                        AutoModelForCausalLM,
                         self.model_name,
                         **model_kwargs,
                     )
@@ -310,7 +314,8 @@ def to_device(self, device):
                         torch_dtype=torch.float16,
                     )
                 else:
-                    self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model, _ = load_pretrained_hf_class_with_hub_fallback(
+                        AutoModelForCausalLM,
                         tmpdir,
                         **model_kwargs,
                     )
3 changes: 2 additions & 1 deletion ludwig/schema/llms/base_model.py
@@ -9,6 +9,7 @@
 from ludwig.error import ConfigValidationError
 from ludwig.schema.metadata import LLM_METADATA
 from ludwig.schema.metadata.parameter_metadata import convert_metadata_to_json
+from ludwig.utils.hf_utils import load_pretrained_hf_class_with_hub_fallback
 
 # Maps a preset LLM name to the full slash-delimited HF path. If the user chooses a preset LLM, the preset LLM name is
 # replaced with the full slash-delimited HF path using this map, after JSON validation but before config object
@@ -55,7 +56,7 @@ def validate(model_name: str):
         if os.path.isdir(model_name):
             return model_name
         try:
-            AutoConfig.from_pretrained(model_name)
+            load_pretrained_hf_class_with_hub_fallback(AutoConfig, model_name)
             return model_name
         except OSError:
             raise ConfigValidationError(
3 changes: 2 additions & 1 deletion ludwig/schema/model_types/utils.py
@@ -34,6 +34,7 @@
 from ludwig.schema.trainer import ECDTrainerConfig
 from ludwig.types import HyperoptConfigDict, ModelConfigDict
 from ludwig.utils.data_utils import get_sanitized_feature_name
+from ludwig.utils.hf_utils import load_pretrained_hf_class_with_hub_fallback
 from ludwig.utils.llm_utils import get_context_len
 
 if TYPE_CHECKING:
@@ -370,7 +371,7 @@ def _get_maximum_possible_sequence_length(config: "ModelConfig", default_max_seq
     # we should fall back to the window size of the pretrained model. By this point, because of schema validation
     # checks, we know that the base_model exists so we can safely grab the base model's config.
     # TODO (Arnav): Figure out how to factor in rope scaling factor into this calculation.
-    model_config = AutoConfig.from_pretrained(config.base_model)
+    model_config, _ = load_pretrained_hf_class_with_hub_fallback(AutoConfig, config.base_model)
     max_possible_sequence_length = get_context_len(model_config)
     # Artifically leave a buffer of half the total model window size to trade off
     # runtime while likely covering a majority of the max sequence length.