From c1c4bbfb0f151058dcb0b3341cc607b7146ea981 Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Tue, 6 Feb 2024 09:54:37 -0800
Subject: [PATCH] Add and use VersionedDeprecationWarning (#944)

* Add and use VersionedDeprecationWarning

* Use remove_version instead.

* Fix merge

* Apply suggestions from code review

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
---
 llmfoundry/__init__.py                     |  2 +-
 llmfoundry/callbacks/generate_callback.py  | 10 ++++---
 llmfoundry/data/packing.py                 |  9 ++++---
 llmfoundry/models/layers/attention.py      | 31 +++++++++++++---------
 llmfoundry/models/mpt/configuration_mpt.py | 12 ++++++---
 llmfoundry/utils/warnings.py               | 27 +++++++++++++++++++
 scripts/train/train.py                     |  8 ++++--
 7 files changed, 73 insertions(+), 26 deletions(-)
 create mode 100644 llmfoundry/utils/warnings.py

diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py
index 87504d26b3..b7e69ff31d 100644
--- a/llmfoundry/__init__.py
+++ b/llmfoundry/__init__.py
@@ -95,4 +95,4 @@
     'TiktokenTokenizerWrapper',
 ]
 
-__version__ = '0.4.0'
+__version__ = '0.5.0'
diff --git a/llmfoundry/callbacks/generate_callback.py b/llmfoundry/callbacks/generate_callback.py
index 58ba7e685e..f144c9dd75 100644
--- a/llmfoundry/callbacks/generate_callback.py
+++ b/llmfoundry/callbacks/generate_callback.py
@@ -11,6 +11,8 @@
 from composer.callbacks import Generate as ComposerGenerate
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
+from llmfoundry.utils.warnings import VersionedDeprecationWarning
+
 Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 
 
@@ -20,10 +22,10 @@ def __init__(self, prompts: List[str], batch_log_interval: int,
                  **kwargs: Any):
 
         warnings.warn(
-            ('Accessing llmfoundry.callbacks.generate_callback.Generate '
-             'is deprecated and will be removed in a future release. '
-             'Please use composer.callbacks.Generate instead.'),
-            DeprecationWarning,
+            VersionedDeprecationWarning('Accessing llmfoundry.callbacks.generate_callback.Generate ' + \
+                'is deprecated. Please use composer.callbacks.Generate instead.',
+                remove_version='0.5.0',
+            )
         )
 
         interval = f'{batch_log_interval}ba'
diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py
index d3084c72c8..0a09e1465b 100644
--- a/llmfoundry/data/packing.py
+++ b/llmfoundry/data/packing.py
@@ -11,6 +11,8 @@
 from omegaconf import DictConfig
 from transformers import PreTrainedTokenizerBase
 
+from llmfoundry.utils.warnings import VersionedDeprecationWarning
+
 log = logging.getLogger(__name__)
 
 
@@ -433,9 +435,10 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
     import warnings
 
     warnings.warn(
-        DeprecationWarning(
-            'Please use scripts/misc/profile_packing.py to profile packing.' +
-            'This script will be removed in later releases.'))
+        VersionedDeprecationWarning(
+            'Please use scripts/misc/profile_packing.py to profile packing.',
+            remove_version='0.5.0',
+        ))
 
     import os
     from argparse import ArgumentParser, Namespace
diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 89f861c3f0..42f9403868 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -16,6 +16,7 @@
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
 from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
+from llmfoundry.utils.warnings import VersionedDeprecationWarning
 
 
 def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -104,14 +105,16 @@ def scaled_multihead_dot_product_attention(
                                            torch.Tensor]]]:
     if multiquery:
         warnings.warn(
-            DeprecationWarning(
-                'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
+            VersionedDeprecationWarning(
+                'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
+                remove_version='0.5.0',
             ))
         kv_n_heads = 1
     elif kv_n_heads is None:
         warnings.warn(
-            DeprecationWarning(
-                'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
+            VersionedDeprecationWarning(
+                'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
+                remove_version='0.5.0',
             ))
         kv_n_heads = n_heads
 
@@ -249,14 +252,16 @@ def flash_attn_fn(
 
     if multiquery:
         warnings.warn(
-            DeprecationWarning(
-                'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
+            VersionedDeprecationWarning(
+                'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
+                remove_version='0.5.0',
             ))
         kv_n_heads = 1
     elif kv_n_heads is None:
         warnings.warn(
-            DeprecationWarning(
-                'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
+            VersionedDeprecationWarning(
+                'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
+                remove_version='0.5.0',
             ))
         kv_n_heads = n_heads
 
@@ -422,14 +427,16 @@ def triton_flash_attn_fn(
 
     if multiquery:
         warnings.warn(
-            DeprecationWarning(
-                'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
+            VersionedDeprecationWarning(
+                'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
+                remove_version='0.5.0',
             ))
         kv_n_heads = 1
     elif kv_n_heads is None:
         warnings.warn(
-            DeprecationWarning(
-                'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
+            VersionedDeprecationWarning(
+                'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
+                remove_version='0.5.0',
             ))
         kv_n_heads = n_heads
 
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index e666897863..bc2c155e4d 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -21,6 +21,8 @@
 from llmfoundry.models.layers.norm import LPLayerNorm  # type: ignore (see note)
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY  # type: ignore (see note)
 
+from llmfoundry.utils.warnings import VersionedDeprecationWarning
+
 ffn_config_defaults: Dict = {
     'ffn_type': 'mptmlp',
 }
@@ -159,8 +161,9 @@ def __init__(
         self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
         if verbose is not None:
             warnings.warn(
-                DeprecationWarning(
-                    'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.'
+                VersionedDeprecationWarning(
+                    'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.',
+                    remove_version='0.5.0',
                 ))
 
         if 'name' in kwargs:
@@ -226,8 +229,9 @@ def _validate_config(self) -> None:
 
         if self.attn_config['attn_impl'] == 'flash' and is_flash_v1_installed():
             warnings.warn(
-                DeprecationWarning(
-                    'Support for Flash Attention v1 is deprecated. Please upgrade to Flash Attention v2.4.2. To install Flash Attention v2.4.2, please run `pip install -e ".[gpu-flash2]"` from the root directory of the llm-foundry repository.'
+                VersionedDeprecationWarning(
+                    'Support for Flash Attention v1 is deprecated. Please upgrade to Flash Attention v2.4.2. To install Flash Attention v2.4.2, please run `pip install -e ".[gpu-flash2]"` from the root directory of the llm-foundry repository.',
+                    remove_version='0.6.0',
                 ))
 
         if self.attn_config[
diff --git a/llmfoundry/utils/warnings.py b/llmfoundry/utils/warnings.py
new file mode 100644
index 0000000000..368c5725c3
--- /dev/null
+++ b/llmfoundry/utils/warnings.py
@@ -0,0 +1,27 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+
+class VersionedDeprecationWarning(DeprecationWarning):
+    """A custom deprecation warning class that includes version information.
+
+    Attributes:
+        message (str): The deprecation message describing why the feature is deprecated.
+        remove_version (str): The version in which the feature will be removed.
+
+    Example:
+        >>> def deprecated_function():
+        ...     warnings.warn(
+        ...         VersionedDeprecationWarning(
+        ...             "Function XYZ is deprecated.",
+        ...             after_version="2.0.0"
+        ...         )
+        ...     )
+        ...
+        >>> deprecated_function()
+        DeprecationWarning: Function XYZ is deprecated. It will be removed in version 2.0.0.
+    """
+
+    def __init__(self, message: str, remove_version: str) -> None:
+        super().__init__(message +
+                         f' It will be removed in version {remove_version}.')
diff --git a/scripts/train/train.py b/scripts/train/train.py
index dbaaf13ebc..0b89d0cc08 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -22,6 +22,8 @@
 from omegaconf import OmegaConf as om
 from rich.traceback import install
 
+from llmfoundry.utils.warnings import VersionedDeprecationWarning
+
 install()
 
 from transformers import PreTrainedTokenizerBase
@@ -219,8 +221,10 @@ def main(cfg: DictConfig) -> Trainer:
                                                    default_value=None)
     if eval_gauntlet_config is not None:
         warnings.warn(
-            'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`',
-            DeprecationWarning)
+            VersionedDeprecationWarning(
+                'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`.',
+                remove_version='0.5.0',
+            ))
     icl_subset_num_batches: Optional[int] = pop_config(cfg,
                                                        'icl_subset_num_batches',
                                                        must_exist=False,