Commit: Resolve merge conflicts

arnavgarg1 committed Jul 28, 2023
2 parents b663b47 + 539e8e0 commit 6d0963b
Showing 7 changed files with 80 additions and 36 deletions.
9 changes: 8 additions & 1 deletion ludwig/data/dataset/pandas.py
@@ -25,7 +25,7 @@
from ludwig.data.sampler import DistributedSampler
from ludwig.distributed import DistributedStrategy
from ludwig.features.base_feature import BaseFeature
from ludwig.utils.data_utils import DATA_TRAIN_HDF5_FP, save_hdf5
from ludwig.utils.data_utils import DATA_TRAIN_HDF5_FP, load_hdf5, save_hdf5
from ludwig.utils.dataframe_utils import from_numpy_dataset, to_numpy_dataset, to_scalar_df
from ludwig.utils.defaults import default_random_seed
from ludwig.utils.fs_utils import download_h5
@@ -37,6 +37,9 @@ def __init__(self, dataset, features, data_hdf5_fp):
self.features = features
self.data_hdf5_fp = data_hdf5_fp
self.size = len(dataset)

if isinstance(dataset, str):
dataset = load_hdf5(dataset)
self.dataset = to_numpy_dataset(dataset)

def to_df(self, features: Optional[Iterable[BaseFeature]] = None) -> DataFrame:
@@ -79,6 +82,10 @@ def get_dataset(self):
def __len__(self):
return self.size

@property
def processed_data_fp(self) -> Optional[str]:
return self.data_hdf5_fp

@property
def in_memory_size_bytes(self):
df = self.to_df()
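
With this change the pandas dataset can be handed either an in-memory dataset or the path to a preprocessed HDF5 cache, and it now exposes that cache path via processed_data_fp. A minimal usage sketch; the class name PandasDataset, the cache path, and the features list are assumptions here, not shown in the diff:

    from ludwig.data.dataset.pandas import PandasDataset  # class name assumed

    # "train.training.hdf5" stands in for a cache written by Ludwig preprocessing,
    # and `features` for the preprocessed feature list passed in by the backend.
    dataset = PandasDataset("train.training.hdf5", features, data_hdf5_fp="train.training.hdf5")
    print(len(dataset))               # row count, as before
    print(dataset.processed_data_fp)  # new property: reports the backing HDF5 file
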
1 change: 1 addition & 0 deletions ludwig/distributed/deepspeed.py
@@ -210,6 +210,7 @@ def save(self, save_path: str, global_step: int):
if self.scheduler is not None:
client_state["scheduler_state"] = self.scheduler.state_dict()

# TODO: set exclude_frozen_parameters=True to only save PEFT weights
self.model.save_checkpoint(save_path, client_state=client_state)

def get_state_for_inference(self, save_path: str, device: Optional[torch.device] = None) -> Mapping[str, Any]:
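
The new comment is only a TODO; nothing in this commit changes the DeepSpeed save itself. A hedged sketch of the follow-up it describes, assuming a DeepSpeed version whose save_checkpoint accepts exclude_frozen_parameters:

    # Skip frozen base-model tensors so only trainable (e.g. PEFT/LoRA) weights are written.
    self.model.save_checkpoint(save_path, client_state=client_state, exclude_frozen_parameters=True)
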
71 changes: 40 additions & 31 deletions ludwig/models/llm.py
@@ -6,7 +6,7 @@

import numpy as np
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaConfig, PreTrainedModel

from ludwig.constants import IGNORE_INDEX_TOKEN_ID, LOGITS, MODEL_LLM, PREDICTIONS, TEXT
from ludwig.features.base_feature import ModuleWrapper, OutputFeature
@@ -71,6 +71,33 @@ def update(self, modules: Dict[str, torch.nn.Module]) -> None:
self.obj.update(modules)


def load_pretrained_from_config(
config_obj: LLMModelConfig,
model_config: Optional[AutoConfig] = None,
weights_save_path: Optional[str] = None,
) -> PreTrainedModel:
load_kwargs = {}
if config_obj.quantization:
# Apply quantization configuration at model load time
load_kwargs["torch_dtype"] = getattr(torch, config_obj.quantization.bnb_4bit_compute_dtype)
load_kwargs["quantization_config"] = config_obj.quantization.to_bitsandbytes()

if config_obj.model_parameters:
# Add any model specific parameters to the load kwargs
for param_name, param_value in config_obj.model_parameters.to_config().items():
# Not all parameters are supported by all models, so we only add the parameter to the load kwargs
# if it is supported by the model.
if hasattr(model_config, param_name):
load_kwargs[param_name] = param_value
else:
logger.warning(f"Parameter {param_name} is not supported by {config_obj.base_model}. Skipping.")

logger.info("Loading large language model...")
pretrained_model_name_or_path = weights_save_path or config_obj.base_model
model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
return model


class LLM(BaseModel):
@staticmethod
def type() -> str:
@@ -91,28 +118,8 @@ def __init__(
self.model_name = self.config_obj.base_model
self.model_config = AutoConfig.from_pretrained(self.config_obj.base_model)

self.load_kwargs = {}
if self.config_obj.quantization:
# Apply quantization configuration at model load time
self.load_kwargs["torch_dtype"] = getattr(torch, self.config_obj.quantization.bnb_4bit_compute_dtype)
self.load_kwargs["quantization_config"] = self.config_obj.quantization.to_bitsandbytes()

if self.config_obj.model_parameters:
# Add any model specific parameters to the load kwargs
for param_name, param_value in self.config_obj.model_parameters.to_config().items():
# Not all parameters are supported by all models, so we only add the parameter to the load kwargs
# if it is supported by the model.
if hasattr(self.model_config, param_name):
self.load_kwargs[param_name] = param_value
else:
logger.warning(
f"Parameter {param_name} is not supported by {self.config_obj.base_model}. Skipping."
)

logger.info("Loading large language model...")
self.model = AutoModelForCausalLM.from_pretrained(self.config_obj.base_model, **self.load_kwargs)
# Model initially loaded onto cpu
self.curr_device = torch.device("cpu")
self.model = load_pretrained_from_config(self.config_obj, model_config=self.model_config)
self.curr_device = next(self.model.parameters()).device
logger.info("Done.")

# Determines the maximum length of the context (input + output tokens)
@@ -323,7 +330,7 @@ def forward(

# Wrap with flash attention backend for faster generation
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False) if (
torch.cuda.is_available() and next(self.model.parameters()).device.type == "cuda"
torch.cuda.is_available() and self.curr_device.type == "cuda"
) else contextlib.nullcontext():
model_outputs = self.model(input_ids=self.model_inputs, attention_mask=self.attention_masks).get(LOGITS)

@@ -384,9 +391,7 @@ def generate(
# Wrap with flash attention backend for faster generation
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=False, enable_mem_efficient=False
) if (
torch.cuda.is_available() and next(self.model.parameters()).device.type == "cuda"
) else contextlib.nullcontext():
) if (torch.cuda.is_available() and self.curr_device.type == "cuda") else contextlib.nullcontext():
# Generate text using the model
model_outputs = self.model.generate(
input_ids=input_ids_sample_no_padding,
@@ -584,13 +589,17 @@ def load(self, save_path):
"""Loads the model from the given path."""
weights_save_path = os.path.join(save_path, MODEL_WEIGHTS_FILE_NAME)
if self.config_obj.adapter:
from peft import PeftConfig, PeftModel # noqa
from peft import PeftModel # noqa

if isinstance(self.model, PeftModel):
# Unwrap and reload PeftModel
self.model = self.model.base_model

config = PeftConfig.from_pretrained(weights_save_path)
self.model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
self.model = PeftModel.from_pretrained(self.model, weights_save_path)
elif self.config_obj.trainer.type != "none":
self.model = AutoModelForCausalLM.from_pretrained(weights_save_path)
self.model = load_pretrained_from_config(
self.config_obj, model_config=self.model_config, weights_save_path=weights_save_path
)
else:
logger.info("Skipped loading LLM without weight adjustments.")

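
Factoring the load path into load_pretrained_from_config means quantization settings and model_parameters are applied the same way on first load and when reloading fine-tuned weights. A hedged usage sketch; config_obj is the LLMModelConfig the model already holds, and the checkpoint directory is illustrative:

    from transformers import AutoConfig
    from ludwig.models.llm import load_pretrained_from_config

    model_config = AutoConfig.from_pretrained(config_obj.base_model)

    # Fresh load from the base model name:
    model = load_pretrained_from_config(config_obj, model_config=model_config)

    # Reload from a saved checkpoint directory instead of the hub (path is illustrative):
    model = load_pretrained_from_config(
        config_obj, model_config=model_config, weights_save_path="results/run/model/model_weights"
    )
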
13 changes: 12 additions & 1 deletion ludwig/schema/llms/quantization.py
@@ -86,4 +86,15 @@ def __init__(self):
super().__init__(QuantizationConfig, default_missing=True)

def _jsonschema_type_mapping(self):
return schema_utils.unload_jsonschema_from_marshmallow_class(QuantizationConfig)
return {
"oneOf": [
{"type": "null", "title": "disabled", "description": "Disable quantization."},
{
**schema_utils.unload_jsonschema_from_marshmallow_class(QuantizationConfig),
"title": "enabled",
"description": "Set quantization options.",
},
],
"title": "quantization",
"description": "",
}
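
The schema now states explicitly that quantization is nullable: null selects the "disabled" branch and an object selects the "enabled" branch. Illustrative values only; the field name bits follows QuantizationConfig and is an assumption here:

    # Both shapes validate against the new oneOf schema (illustrative):
    quantization_disabled = None
    quantization_enabled = {"bits": 4}
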
3 changes: 2 additions & 1 deletion ludwig/trainers/trainer.py
@@ -823,7 +823,8 @@ def train(
if self.distributed.is_model_parallel():
# Assume the full weights cannot fit in memory on GPU
self.model = self.model.cpu()
self.model.load_state_dict(state_dict)
_, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
assert unexpected_keys == [], f"Unexpected keys found in state dict: {unexpected_keys}"
elif return_state_dict:
state_dict = self.model.cpu().state_dict()

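
strict=False is needed because checkpoints written after this change omit frozen parameters, so missing keys are expected while unexpected keys still signal a real mismatch. A small self-contained illustration of that contract:

    import torch

    model = torch.nn.Linear(4, 2)
    partial_state = {"weight": torch.zeros(2, 4)}    # "bias" deliberately absent
    missing, unexpected = model.load_state_dict(partial_state, strict=False)
    assert missing == ["bias"] and unexpected == []  # missing keys tolerated, unknown keys rejected
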
4 changes: 4 additions & 0 deletions ludwig/utils/batch_size_tuner.py
@@ -100,7 +100,11 @@ def evaluate(self, batch_size: int, total_steps: int = 5) -> float:
start_ts = time.time()
self.step(batch_size)
durations.append(time.time() - start_ts)

med_duration_s = statistics.median(durations)
if med_duration_s == 0.0:
return float("inf")

return batch_size / med_duration_s

def reset(self):
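
The guard handles the case where every sampled step duration rounds to zero, which would otherwise raise ZeroDivisionError when computing throughput; returning inf keeps the tuner's "higher is better" comparison well defined. A minimal illustration:

    import statistics

    durations = [0.0, 0.0, 0.0]  # timer resolution can make very fast steps read as 0.0
    med_duration_s = statistics.median(durations)
    throughput = float("inf") if med_duration_s == 0.0 else 128 / med_duration_s
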
15 changes: 13 additions & 2 deletions ludwig/utils/checkpoint_utils.py
@@ -139,7 +139,8 @@ def load(self, save_path: str, device: Optional[torch.device] = None) -> bool:
state = torch.load(save_path, map_location=device)
try:
self.global_step = self._get_global_step(state, save_path)
self.model.load_state_dict(state["model_weights"])
_, unexpected_keys = self.model.load_state_dict(state["model_weights"], strict=False)
assert unexpected_keys == [], f"Unexpected keys found in state dict: {unexpected_keys}"
if self.optimizer is not None:
self.optimizer.load_state_dict(state["optim_state"])
if self.scheduler is not None and "scheduler_state" in state:
@@ -175,7 +176,7 @@ def save(self, save_path: str, global_step: int):
if self.is_local_rank_0():
state = {
"global_step": global_step,
"model_weights": self.model.state_dict(),
"model_weights": self.get_model_state_dict(),
}
if self.optimizer is not None:
state["optim_state"] = self.optimizer.state_dict()
@@ -208,6 +209,16 @@ def save(self, save_path: str, global_step: int):
signal.signal(signal.SIGINT, orig_handler)
self.distributed.barrier()

def get_model_state_dict(self) -> Dict[str, Any]:
state = self.model.state_dict()

# Remove frozen parameter weights from state_dict for adapters and pretrained models
for n, p in self.model.named_parameters():
if not p.requires_grad:
del state[n]

return state

def is_local_rank_0(self) -> bool:
return self.distributed.local_rank() == 0

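
get_model_state_dict drops every parameter with requires_grad=False, which is what makes the strict=False load above necessary. A toy illustration of the same filtering on a module with one frozen layer:

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    for p in model[0].parameters():  # freeze the first layer
        p.requires_grad = False

    state = model.state_dict()
    for name, param in model.named_parameters():
        if not param.requires_grad:
            del state[name]          # mirrors get_model_state_dict

    assert sorted(state) == ["1.bias", "1.weight"]  # only trainable weights remain
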
