From 2e0a845863dc18ad8ddac71fc7f18a9940989891 Mon Sep 17 00:00:00 2001
From: Jerry Chen
Date: Mon, 5 Feb 2024 14:00:38 -0800
Subject: [PATCH] Retrieve license information when local files are provided for a pretrained model (#943)

* Initial implementation to test

* Add log for license overwrite

* Use Path for input to _write_license_information

* Set default

---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
---
 llmfoundry/callbacks/hf_checkpointer.py | 36 ++++++++++++++++++++++---
 1 file changed, 32 insertions(+), 4 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 1ece1bff75..c0db8e1c28 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -20,6 +20,7 @@
                             maybe_create_remote_uploader_downloader_from_uri,
                             parse_uri)
 from composer.utils.misc import create_interval_scheduler
+from mlflow.transformers import _fetch_model_card, _write_license_information
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
@@ -32,17 +33,41 @@
 _LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)
 
 
-def _maybe_get_license_filename(local_dir: str) -> Optional[str]:
+def _maybe_get_license_filename(
+        local_dir: str,
+        pretrained_model_name: Optional[str] = None) -> Optional[str]:
     """Returns the name of the license file if it exists in the local_dir.
 
     Note: This is intended to be consistent with the code in MLflow.
     https://github.com/mlflow/mlflow/blob/5d13d6ec620a02de9a5e31201bf1becdb9722ea5/mlflow/transformers/__init__.py#L1152
 
+    Since LLM Foundry supports local model files being used rather than fetching the files from the Hugging Face Hub,
+    MLflow's logic to fetch and write the license information on model save is not applicable; it will try to search for
+    a Hugging Face repo named after the local path. However, the user can provide the original pretrained model name,
+    in which case this function will use that to fetch the correct license information.
+
     If the license file does not exist, returns None.
     """
     try:
-        return next(file for file in os.listdir(local_dir)
-                    if _LICENSE_FILE_PATTERN.search(file))
+        license_filename = next(file for file in os.listdir(local_dir)
+                                if _LICENSE_FILE_PATTERN.search(file))
+
+        # If a pretrained model name is provided, replace the license file with the correct info from HF Hub.
+        if pretrained_model_name is not None:
+            log.info(
+                f'Overwriting license file {license_filename} with license info for model {pretrained_model_name} from Hugging Face Hub'
+            )
+            os.remove(os.path.join(local_dir, license_filename))
+            model_card = _fetch_model_card(pretrained_model_name)
+
+            local_dir_path = Path(local_dir).absolute()
+            _write_license_information(pretrained_model_name, model_card,
+                                       local_dir_path)
+            license_filename = next(file for file in os.listdir(local_dir)
+                                    if _LICENSE_FILE_PATTERN.search(file))
+
+        return license_filename
+
     except StopIteration:
         return None
 
@@ -330,8 +355,11 @@ def _save_checkpoint(self, state: State, logger: Logger):
 
                         mlflow_logger.save_model(**model_saving_kwargs)
 
+                        # Upload the license file generated by mlflow during the model saving.
                         license_filename = _maybe_get_license_filename(
-                            local_save_path)
+                            local_save_path,
+                            self.mlflow_logging_config['metadata'].get(
+                                'pretrained_model_name', None))
                         if license_filename is not None:
                             mlflow_logger._mlflow_client.log_artifact(
                                 mlflow_logger._run_id,