From 279a9ba1cb516948bf71754f973bfc2d7868fee2 Mon Sep 17 00:00:00 2001
From: Kabir Brar
Date: Tue, 25 Jul 2023 10:05:33 -0700
Subject: [PATCH 1/3] int: Fix quantization schema (#3479)

---
 ludwig/schema/llms/quantization.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/ludwig/schema/llms/quantization.py b/ludwig/schema/llms/quantization.py
index 04eae3380b3..2812f59afa9 100644
--- a/ludwig/schema/llms/quantization.py
+++ b/ludwig/schema/llms/quantization.py
@@ -86,4 +86,15 @@ def __init__(self):
         super().__init__(QuantizationConfig, default_missing=True)
 
     def _jsonschema_type_mapping(self):
-        return schema_utils.unload_jsonschema_from_marshmallow_class(QuantizationConfig)
+        return {
+            "oneOf": [
+                {"type": "null", "title": "disabled", "description": "Disable quantization."},
+                {
+                    **schema_utils.unload_jsonschema_from_marshmallow_class(QuantizationConfig),
+                    "title": "enabled",
+                    "description": "Set quantization options.",
+                },
+            ],
+            "title": "quantization",
+            "description": "",
+        }

From 446c747a1e18e4e1897aadd68fe0aa87f661226a Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Wed, 26 Jul 2023 15:55:29 -0700
Subject: [PATCH 2/3] Fixed divide by zero when tuning batch size (#3481)

---
 ludwig/utils/batch_size_tuner.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/ludwig/utils/batch_size_tuner.py b/ludwig/utils/batch_size_tuner.py
index d1a0c857253..74ccc777b51 100644
--- a/ludwig/utils/batch_size_tuner.py
+++ b/ludwig/utils/batch_size_tuner.py
@@ -100,7 +100,11 @@ def evaluate(self, batch_size: int, total_steps: int = 5) -> float:
             start_ts = time.time()
             self.step(batch_size)
             durations.append(time.time() - start_ts)
+
         med_duration_s = statistics.median(durations)
+        if med_duration_s == 0.0:
+            return float("inf")
+
         return batch_size / med_duration_s
 
     def reset(self):
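For context on the fix above: when the timed steps are fast enough that the median step duration rounds to 0.0, the old throughput computation raised ZeroDivisionError. The snippet below is a minimal standalone sketch of the guarded computation; the samples_per_sec helper name is illustrative and not part of Ludwig's API.

import statistics

def samples_per_sec(batch_size: int, durations: list) -> float:
    # Guarded throughput computation: a zero median duration means the steps were too
    # fast to time, so report infinite throughput instead of dividing by zero.
    med_duration_s = statistics.median(durations)
    if med_duration_s == 0.0:
        return float("inf")
    return batch_size / med_duration_s

print(samples_per_sec(128, [0.02, 0.03, 0.025]))  # 5120.0 samples/sec
print(samples_per_sec(128, [0.0, 0.0, 0.0]))      # inf instead of ZeroDivisionError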
From 539e8e056841204570c1d21498de2ab30a9baaae Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Fri, 28 Jul 2023 09:48:49 -0700
Subject: [PATCH 3/3] Exclude frozen weights from checkpoints and fix evaluation using quantized LLMs (#3483)

---
 ludwig/data/dataset/pandas.py    |  9 ++++++-
 ludwig/distributed/deepspeed.py  |  1 +
 ludwig/models/llm.py             | 44 +++++++++++++++++---------------
 ludwig/trainers/trainer.py       |  3 ++-
 ludwig/utils/checkpoint_utils.py | 15 +++++++++--
 5 files changed, 48 insertions(+), 24 deletions(-)

diff --git a/ludwig/data/dataset/pandas.py b/ludwig/data/dataset/pandas.py
index 7b495fa418e..6cec6ec6bf8 100644
--- a/ludwig/data/dataset/pandas.py
+++ b/ludwig/data/dataset/pandas.py
@@ -25,7 +25,7 @@
 from ludwig.data.sampler import DistributedSampler
 from ludwig.distributed import DistributedStrategy
 from ludwig.features.base_feature import BaseFeature
-from ludwig.utils.data_utils import DATA_TRAIN_HDF5_FP, save_hdf5
+from ludwig.utils.data_utils import DATA_TRAIN_HDF5_FP, load_hdf5, save_hdf5
 from ludwig.utils.dataframe_utils import from_numpy_dataset, to_numpy_dataset, to_scalar_df
 from ludwig.utils.defaults import default_random_seed
 from ludwig.utils.fs_utils import download_h5
@@ -37,6 +37,9 @@ def __init__(self, dataset, features, data_hdf5_fp):
         self.features = features
         self.data_hdf5_fp = data_hdf5_fp
         self.size = len(dataset)
+
+        if isinstance(dataset, str):
+            dataset = load_hdf5(dataset)
         self.dataset = to_numpy_dataset(dataset)
 
     def to_df(self, features: Optional[Iterable[BaseFeature]] = None) -> DataFrame:
@@ -79,6 +82,10 @@ def get_dataset(self):
     def __len__(self):
         return self.size
 
+    @property
+    def processed_data_fp(self) -> Optional[str]:
+        return self.data_hdf5_fp
+
     @property
     def in_memory_size_bytes(self):
         df = self.to_df()

diff --git a/ludwig/distributed/deepspeed.py b/ludwig/distributed/deepspeed.py
index f92577f1753..c708c1865be 100644
--- a/ludwig/distributed/deepspeed.py
+++ b/ludwig/distributed/deepspeed.py
@@ -210,6 +210,7 @@ def save(self, save_path: str, global_step: int):
         if self.scheduler is not None:
             client_state["scheduler_state"] = self.scheduler.state_dict()
 
+        # TODO: set exclude_frozen_parameters=True to only save PEFT weights
         self.model.save_checkpoint(save_path, client_state=client_state)
 
     def get_state_for_inference(self, save_path: str, device: Optional[torch.device] = None) -> Mapping[str, Any]:

diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py
index 92a2557c6ee..9d67904e371 100644
--- a/ludwig/models/llm.py
+++ b/ludwig/models/llm.py
@@ -6,7 +6,7 @@
 import numpy as np
 import torch
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaConfig
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaConfig, PreTrainedModel
 
 from ludwig.constants import IGNORE_INDEX_TOKEN_ID, LOGITS, MODEL_LLM, PREDICTIONS, TEXT
 from ludwig.features.base_feature import ModuleWrapper, OutputFeature
@@ -71,6 +71,19 @@ def update(self, modules: Dict[str, torch.nn.Module]) -> None:
         self.obj.update(modules)
 
 
+def load_pretrained_from_config(config_obj: LLMModelConfig, weights_save_path: Optional[str] = None) -> PreTrainedModel:
+    load_kwargs = {}
+    if config_obj.quantization:
+        # Apply quanitzation configuration at model load time
+        load_kwargs["torch_dtype"] = getattr(torch, config_obj.quantization.bnb_4bit_compute_dtype)
+        load_kwargs["quantization_config"] = config_obj.quantization.to_bitsandbytes()
+
+    logger.info("Loading large language model...")
+    pretrained_model_name_or_path = weights_save_path or config_obj.base_model
+    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
+    return model
+
+
 class LLM(BaseModel):
     @staticmethod
     def type() -> str:
@@ -91,17 +104,8 @@ def __init__(
         self.model_name = self.config_obj.base_model
         self.model_config = AutoConfig.from_pretrained(self.config_obj.base_model)
 
-        self.load_kwargs = {}
-        if self.config_obj.quantization:
-            # Apply quanitzation configuration at model load time
-            self.load_kwargs["torch_dtype"] = getattr(torch, self.config_obj.quantization.bnb_4bit_compute_dtype)
-            self.load_kwargs["quantization_config"] = self.config_obj.quantization.to_bitsandbytes()
-
-        logger.info("Loading large language model...")
-        self.model = AutoModelForCausalLM.from_pretrained(self.config_obj.base_model, **self.load_kwargs)
-
-        # Model initially loaded onto cpu
-        self.curr_device = torch.device("cpu")
+        self.model = load_pretrained_from_config(self.config_obj)
+        self.curr_device = next(self.model.parameters()).device
         logger.info("Done.")
 
         # Determines the maximum length of the context (input + output tokens)
@@ -312,7 +316,7 @@ def forward(
 
         # Wrap with flash attention backend for faster generation
         with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False) if (
-            torch.cuda.is_available() and next(self.model.parameters()).device.type == "cuda"
+            torch.cuda.is_available() and self.curr_device.type == "cuda"
         ) else contextlib.nullcontext():
             model_outputs = self.model(input_ids=self.model_inputs, attention_mask=self.attention_masks).get(LOGITS)
 
@@ -373,9 +377,7 @@ def generate(
             # Wrap with flash attention backend for faster generation
             with torch.backends.cuda.sdp_kernel(
                 enable_flash=True, enable_math=False, enable_mem_efficient=False
-            ) if (
-                torch.cuda.is_available() and next(self.model.parameters()).device.type == "cuda"
-            ) else contextlib.nullcontext():
+            ) if (torch.cuda.is_available() and self.curr_device.type == "cuda") else contextlib.nullcontext():
                 # Generate text using the model
                 model_outputs = self.model.generate(
                     input_ids=input_ids_sample_no_padding,
@@ -573,13 +575,15 @@ def load(self, save_path):
         """Loads the model from the given path."""
         weights_save_path = os.path.join(save_path, MODEL_WEIGHTS_FILE_NAME)
         if self.config_obj.adapter:
-            from peft import PeftConfig, PeftModel  # noqa
+            from peft import PeftModel  # noqa
+
+            if isinstance(self.model, PeftModel):
+                # Unwrap and reload PeftModel
+                self.model = self.model.base_model
 
-            config = PeftConfig.from_pretrained(weights_save_path)
-            self.model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
             self.model = PeftModel.from_pretrained(self.model, weights_save_path)
         elif self.config_obj.trainer.type != "none":
-            self.model = AutoModelForCausalLM.from_pretrained(weights_save_path)
+            self.model = load_pretrained_from_config(self.config_obj, weights_save_path)
         else:
             logger.info("Skipped loading LLM without weight adjustments.")
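The load_pretrained_from_config helper above centralizes the quantized load path so initial training and post-training evaluation reload go through the same code. Purely as an illustration of what QuantizationConfig.to_bitsandbytes() is expected to produce, a 4-bit load through the standard transformers API looks roughly like the sketch below; the base model name is a placeholder, and running it assumes a CUDA GPU with bitsandbytes and accelerate installed. Because quantized weights are typically materialized on the GPU at load time, the patched __init__ derives curr_device from the model's parameters instead of assuming an initial CPU placement.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Illustrative stand-in for what config_obj.quantization.to_bitsandbytes() builds.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # placeholder base model, not taken from the patch
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)

# Mirrors the patched __init__: the current device comes from the loaded parameters.
curr_device = next(model.parameters()).device
print(curr_device)  # usually cuda:0 when quantized weights are materialized on GPU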
diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py
index 5c806e83779..7e6bb3d1294 100644
--- a/ludwig/trainers/trainer.py
+++ b/ludwig/trainers/trainer.py
@@ -823,7 +823,8 @@ def train(
                 if self.distributed.is_model_parallel():
                     # Assume the full weights cannot fit in memory on GPU
                     self.model = self.model.cpu()
-                self.model.load_state_dict(state_dict)
+                _, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
+                assert unexpected_keys == [], f"Unexpected keys found in state dict: {unexpected_keys}"
             elif return_state_dict:
                 state_dict = self.model.cpu().state_dict()

diff --git a/ludwig/utils/checkpoint_utils.py b/ludwig/utils/checkpoint_utils.py
index 756b175205c..ebffa1e1a0b 100644
--- a/ludwig/utils/checkpoint_utils.py
+++ b/ludwig/utils/checkpoint_utils.py
@@ -139,7 +139,8 @@ def load(self, save_path: str, device: Optional[torch.device] = None) -> bool:
             state = torch.load(save_path, map_location=device)
             try:
                 self.global_step = self._get_global_step(state, save_path)
-                self.model.load_state_dict(state["model_weights"])
+                _, unexpected_keys = self.model.load_state_dict(state["model_weights"], strict=False)
+                assert unexpected_keys == [], f"Unexpected keys found in state dict: {unexpected_keys}"
                 if self.optimizer is not None:
                     self.optimizer.load_state_dict(state["optim_state"])
                 if self.scheduler is not None and "scheduler_state" in state:
@@ -175,7 +176,7 @@ def save(self, save_path: str, global_step: int):
         if self.is_local_rank_0():
             state = {
                 "global_step": global_step,
-                "model_weights": self.model.state_dict(),
+                "model_weights": self.get_model_state_dict(),
             }
             if self.optimizer is not None:
                 state["optim_state"] = self.optimizer.state_dict()
@@ -208,6 +209,16 @@ def save(self, save_path: str, global_step: int):
             signal.signal(signal.SIGINT, orig_handler)
         self.distributed.barrier()
 
+    def get_model_state_dict(self) -> Dict[str, Any]:
+        state = self.model.state_dict()
+
+        # Remove frozen parameter weights from state_dict for adapters and pretrained models
+        for n, p in self.model.named_parameters():
+            if not p.requires_grad:
+                del state[n]
+
+        return state
+
     def is_local_rank_0(self) -> bool:
         return self.distributed.local_rank() == 0
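Taken together, the trainer and checkpoint changes mean checkpoints only carry trainable weights, and the load side tolerates the resulting missing keys while still flagging unexpected ones. The snippet below is a self-contained sketch of that round trip in plain PyTorch; the TinyAdapter module is illustrative and not part of Ludwig.

import torch

class TinyAdapter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.base = torch.nn.Linear(4, 4)     # stands in for frozen pretrained weights
        self.adapter = torch.nn.Linear(4, 4)  # stands in for trainable PEFT weights
        self.base.weight.requires_grad = False
        self.base.bias.requires_grad = False

model = TinyAdapter()

# Mirror Checkpoint.get_model_state_dict(): drop everything that does not require grad.
state = model.state_dict()
for name, param in model.named_parameters():
    if not param.requires_grad:
        del state[name]

# Mirror the patched load path: missing (frozen) keys are fine, unexpected keys are not.
missing, unexpected = model.load_state_dict(state, strict=False)
assert unexpected == []
print(sorted(missing))  # ['base.bias', 'base.weight']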