diff --git a/TUTORIAL.md b/TUTORIAL.md index 9563fa4191..36993bc409 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -342,7 +342,7 @@ lora: ``` - In the current release, these features have Beta support. - For efficiency, The MPT model concatenates the `Q`, `K`, and `V` matrices in each attention block into a single `Wqkv` matrix that is three times wider. Currently, LoRA supports a low-rank approximation to this `Wqkv` matrix. -- Known issue: PEFT / LoRA do not directly work with FSDP. +- When evaluating with PEFT / LoRA seperated weight, just set `pretrained_lora_id_or_path` in `model`(Find an example [here](scripts/eval/yamls/hf_lora_eval.yml#L19)). ### Can I quantize these models and/or run on CPU? - The LLM Foundry codebase does not directly have examples of quantization or limited-resource inference. But you can check out [GGML](https://github.com/ggerganov/ggml) (same library that powers llama.cpp) which has built support for efficiently running MPT models on CPU! You _can_ load your model in 8-bit precision for inference using the [bitsandbytes library](https://github.com/TimDettmers/bitsandbytes) and Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index) via `load model = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=True, device_map="auto", trust_remote_code=True)`, although we have not extensively benchmarked the performance (see the Hugging Face [quantization documentation](https://huggingface.co/docs/transformers/main/main_classes/quantization) for more detail). diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 58dcbc940a..60c8656b5e 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -15,15 +15,60 @@ from composer.utils import dist, get_device, reproducibility from omegaconf import DictConfig from omegaconf import OmegaConf as om -from transformers import PreTrainedTokenizerBase +from transformers import (AutoModelForCausalLM, PreTrainedTokenizerBase, + T5ForConditionalGeneration) from llmfoundry.callbacks import ModelGauntlet from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY +from llmfoundry.models.mpt import MPTForCausalLM from llmfoundry.utils.builders import (build_icl_evaluators, build_logger, build_tokenizer) from llmfoundry.utils.config_utils import process_init_device +def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, + num_retries: int) -> Optional[ComposerModel]: + try: + from peft import PeftModel + except ImportError as e: + raise ImportError( + f'Error importing from peft. Run `pip install -e .[gpu,peft]`. \n {e}' + ) + + model_registry = { + 'mpt_causal_lm': MPTForCausalLM, + 'hf_causal_lm': AutoModelForCausalLM, + 'hf_prefix_lm': AutoModelForCausalLM, + 'hf_t5': T5ForConditionalGeneration, + } + + retries = 0 + while retries < num_retries: + try: + trust_remote_code = model_cfg.get('trust_remote_code', True) + use_auth_token = model_cfg.get('use_auth_token', False) + model = model_registry[model_cfg.name].from_pretrained( + model_cfg.pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, + ) + + peft_model = PeftModel.from_pretrained( + model, model_cfg.pretrained_lora_id_or_path) + + composer_model = COMPOSER_MODEL_REGISTRY[model_cfg.name](peft_model, + tokenizer) + return composer_model + except Exception as e: + retries += 1 + if retries >= num_retries: + raise e + else: + print( + f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining' + ) + + def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, fsdp_config: Optional[Dict], num_retries: int) -> Optional[ComposerModel]: @@ -76,8 +121,12 @@ def evaluate_model(model_cfg: DictConfig, cfg: DictConfig, run_name: str, fsdp_config, resolve=True) if fsdp_config is not None else None assert isinstance(fsdp_config, Dict) or fsdp_config is None - composer_model = load_model(model_cfg.model, tokenizer, fsdp_config, - cfg.get('num_retries', 3)) + if hasattr(model_cfg.model, 'pretrained_lora_id_or_path'): + composer_model = load_peft_model(model_cfg.model, tokenizer, + cfg.get('num_retries', 3)) + else: + composer_model = load_model(model_cfg.model, tokenizer, fsdp_config, + cfg.get('num_retries', 3)) if model_gauntlet_df is None and model_gauntlet is not None: model_gauntlet_df = pd.DataFrame( diff --git a/scripts/eval/yamls/hf_lora_eval.yml b/scripts/eval/yamls/hf_lora_eval.yml new file mode 100644 index 0000000000..9b8328c589 --- /dev/null +++ b/scripts/eval/yamls/hf_lora_eval.yml @@ -0,0 +1,48 @@ +max_seq_len: 2048 +seed: 1 +precision: amp_fp16 + +# If you are using one model, put it here: +model_name_or_path: EleutherAI/gpt-neo-125m +# If you are using a seperated lora weight, put it here: +lora_id_or_path: nathan0/lora-gpt-neo-125m-alpaca +# otherwise, write a block for each model you want to test in the `models` section + +models: +- + model_name: ${model_name_or_path} + model: + name: hf_causal_lm + pretrained_model_name_or_path: ${model_name_or_path} + init_device: cpu + pretrained: true + pretrained_lora_id_or_path: ${lora_id_or_path} + tokenizer: + name: ${model_name_or_path} + kwargs: + model_max_length: ${max_seq_len} +# # if you are evaluating more than one model, list them all as YAML blocks without variable interpolation +# - +# model_name: mosaicml/mpt-7b +# model: +# name: hf_causal_lm +# pretrained_model_name_or_path: mosaicml/mpt-7b +# init_device: cpu +# pretrained: true +# config_overrides: +# max_seq_len: ${max_seq_len} +# tokenizer: +# name: mosaicml/mpt-7b +# kwargs: +# model_max_length: ${max_seq_len} + + +device_eval_batch_size: 4 + +# FSDP config for model sharding +fsdp_config: + sharding_strategy: FULL_SHARD + mixed_precision: FULL + +icl_tasks: 'eval/yamls/tasks_light.yaml' +model_gauntlet: 'eval/yamls/model_gauntlet.yaml' diff --git a/setup.py b/setup.py index 0609c97143..9e65a14fd1 100644 --- a/setup.py +++ b/setup.py @@ -93,7 +93,7 @@ 'scipy>=1.10.0,<=1.11.0', # bitsandbytes dependency; TODO: eliminate when incorporated to bitsandbytes # TODO: pin peft when it stabilizes. # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI - 'peft@git+https://github.com/huggingface/peft.git', + 'peft==0.4.0', ] extra_deps['all'] = set(dep for deps in extra_deps.values() for dep in deps)