Skip to content

Commit

Permalink
Retrieve license information when local files are provided for a pret…
Browse files Browse the repository at this point in the history
…rained model (#943)

* Initial implementation to test

* Add log for license overwrite

* Use Path for input to _write_license_information

* Set default

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
jerrychen109 and dakinggg authored Feb 5, 2024
1 parent 3f21bb7 commit 2e0a845
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
maybe_create_remote_uploader_downloader_from_uri,
parse_uri)
from composer.utils.misc import create_interval_scheduler
from mlflow.transformers import _fetch_model_card, _write_license_information
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
Expand All @@ -32,17 +33,41 @@
_LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)


def _maybe_get_license_filename(local_dir: str) -> Optional[str]:
def _maybe_get_license_filename(
local_dir: str,
pretrained_model_name: Optional[str] = None) -> Optional[str]:
"""Returns the name of the license file if it exists in the local_dir.
Note: This is intended to be consistent with the code in MLflow.
https://github.com/mlflow/mlflow/blob/5d13d6ec620a02de9a5e31201bf1becdb9722ea5/mlflow/transformers/__init__.py#L1152
Since LLM Foundry supports local model files being used rather than fetching the files from the Hugging Face Hub,
MLflow's logic to fetch and write the license information on model save is not applicable; it will try to search for
a Hugging Face repo named after the local path. However, the user can provide the original pretrained model name,
in which case this function will use that to fetch the correct license information.
If the license file does not exist, returns None.
"""
try:
return next(file for file in os.listdir(local_dir)
if _LICENSE_FILE_PATTERN.search(file))
license_filename = next(file for file in os.listdir(local_dir)
if _LICENSE_FILE_PATTERN.search(file))

# If a pretrained model name is provided, replace the license file with the correct info from HF Hub.
if pretrained_model_name is not None:
log.info(
f'Overwriting license file {license_filename} with license info for model {pretrained_model_name} from Hugging Face Hub'
)
os.remove(os.path.join(local_dir, license_filename))
model_card = _fetch_model_card(pretrained_model_name)

local_dir_path = Path(local_dir).absolute()
_write_license_information(pretrained_model_name, model_card,
local_dir_path)
license_filename = next(file for file in os.listdir(local_dir)
if _LICENSE_FILE_PATTERN.search(file))

return license_filename

except StopIteration:
return None

Expand Down Expand Up @@ -330,8 +355,11 @@ def _save_checkpoint(self, state: State, logger: Logger):

mlflow_logger.save_model(**model_saving_kwargs)

# Upload the license file generated by mlflow during the model saving.
license_filename = _maybe_get_license_filename(
local_save_path)
local_save_path,
self.mlflow_logging_config['metadata'].get(
'pretrained_model_name', None))
if license_filename is not None:
mlflow_logger._mlflow_client.log_artifact(
mlflow_logger._run_id,
Expand Down

0 comments on commit 2e0a845

Please sign in to comment.