Skip to content

Commit

Permalink
Fix tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
justinxzhao committed Oct 27, 2023
1 parent e42d3e3 commit b7ac42c
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions ludwig/models/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def load_pretrained_from_config(

logger.info("Loading large language model...")
pretrained_model_name_or_path = weights_save_path or config_obj.base_model
model: PreTrainedModel = load_pretrained_hf_class_with_hub_fallback(
model, _ = load_pretrained_hf_class_with_hub_fallback(
AutoModelForCausalLM, pretrained_model_name_or_path, **load_kwargs
)
return model
Expand All @@ -126,7 +126,7 @@ def __init__(
self._random_seed = random_seed

self.model_name = self.config_obj.base_model
self.model_config = load_pretrained_hf_class_with_hub_fallback(AutoConfig, self.config_obj.base_model)
self.model_config, _ = load_pretrained_hf_class_with_hub_fallback(AutoConfig, self.config_obj.base_model)

self.model = load_pretrained_from_config(self.config_obj, model_config=self.model_config)
self.curr_device = next(self.model.parameters()).device
Expand Down Expand Up @@ -303,7 +303,7 @@ def to_device(self, device):
if self.config_obj.adapter:
from peft import PeftModel

self.model = load_pretrained_hf_class_with_hub_fallback(
self.model, _ = load_pretrained_hf_class_with_hub_fallback(
AutoModelForCausalLM,
self.model_name,
**model_kwargs,
Expand All @@ -314,7 +314,7 @@ def to_device(self, device):
torch_dtype=torch.float16,
)
else:
self.model = load_pretrained_hf_class_with_hub_fallback(
self.model, _ = load_pretrained_hf_class_with_hub_fallback(
AutoModelForCausalLM,
tmpdir,
**model_kwargs,
Expand Down Expand Up @@ -698,7 +698,7 @@ def load(self, save_path):
# Unwrap and reload PeftModel
self.model = self.model.base_model

self.model = load_pretrained_hf_class_with_hub_fallback(PeftModel, self.model, weights_save_path)
self.model = PeftModel.from_pretrained(self.model, weights_save_path)
elif self.config_obj.trainer.type != "none":
self.model = load_pretrained_from_config(
self.config_obj, model_config=self.model_config, weights_save_path=weights_save_path
Expand Down

0 comments on commit b7ac42c

Please sign in to comment.