Skip to content

Commit

Permalink
Add and use VersionedDeprecationWarning (#944)
Browse files Browse the repository at this point in the history
* Add and use VersionedDeprecationWarning

* Use remove_version instead.

* Fix merge

* Apply suggestions from code review

Co-authored-by: Daniel King <[email protected]>

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
irenedea and dakinggg authored Feb 6, 2024
1 parent 2e0a845 commit c1c4bbf
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 26 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,4 @@
'TiktokenTokenizerWrapper',
]

__version__ = '0.4.0'
__version__ = '0.5.0'
10 changes: 6 additions & 4 deletions llmfoundry/callbacks/generate_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from composer.callbacks import Generate as ComposerGenerate
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from llmfoundry.utils.warnings import VersionedDeprecationWarning

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


Expand All @@ -20,10 +22,10 @@ def __init__(self, prompts: List[str], batch_log_interval: int,
**kwargs: Any):

warnings.warn(
('Accessing llmfoundry.callbacks.generate_callback.Generate '
'is deprecated and will be removed in a future release. '
'Please use composer.callbacks.Generate instead.'),
DeprecationWarning,
VersionedDeprecationWarning('Accessing llmfoundry.callbacks.generate_callback.Generate ' + \
'is deprecated. Please use composer.callbacks.Generate instead.',
remove_version='0.5.0',
)
)

interval = f'{batch_log_interval}ba'
Expand Down
9 changes: 6 additions & 3 deletions llmfoundry/data/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

from llmfoundry.utils.warnings import VersionedDeprecationWarning

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -433,9 +435,10 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
import warnings

warnings.warn(
DeprecationWarning(
'Please use scripts/misc/profile_packing.py to profile packing.' +
'This script will be removed in later releases.'))
VersionedDeprecationWarning(
'Please use scripts/misc/profile_packing.py to profile packing.',
remove_version='0.5.0',
))

import os
from argparse import ArgumentParser, Namespace
Expand Down
31 changes: 19 additions & 12 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.utils.warnings import VersionedDeprecationWarning


def is_flash_v2_installed(v2_version: str = '2.0.0'):
Expand Down Expand Up @@ -104,14 +105,16 @@ def scaled_multihead_dot_product_attention(
torch.Tensor]]]:
if multiquery:
warnings.warn(
DeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
VersionedDeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
remove_version='0.5.0',
))
kv_n_heads = 1
elif kv_n_heads is None:
warnings.warn(
DeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
VersionedDeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
remove_version='0.5.0',
))
kv_n_heads = n_heads

Expand Down Expand Up @@ -249,14 +252,16 @@ def flash_attn_fn(

if multiquery:
warnings.warn(
DeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
VersionedDeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
remove_version='0.5.0',
))
kv_n_heads = 1
elif kv_n_heads is None:
warnings.warn(
DeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
VersionedDeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
remove_version='0.5.0',
))
kv_n_heads = n_heads

Expand Down Expand Up @@ -422,14 +427,16 @@ def triton_flash_attn_fn(

if multiquery:
warnings.warn(
DeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
VersionedDeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
remove_version='0.5.0',
))
kv_n_heads = 1
elif kv_n_heads is None:
warnings.warn(
DeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
VersionedDeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
remove_version='0.5.0',
))
kv_n_heads = n_heads

Expand Down
12 changes: 8 additions & 4 deletions llmfoundry/models/mpt/configuration_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note)

from llmfoundry.utils.warnings import VersionedDeprecationWarning

ffn_config_defaults: Dict = {
'ffn_type': 'mptmlp',
}
Expand Down Expand Up @@ -159,8 +161,9 @@ def __init__(
self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
if verbose is not None:
warnings.warn(
DeprecationWarning(
'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.'
VersionedDeprecationWarning(
'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.',
remove_version='0.5.0',
))

if 'name' in kwargs:
Expand Down Expand Up @@ -226,8 +229,9 @@ def _validate_config(self) -> None:

if self.attn_config['attn_impl'] == 'flash' and is_flash_v1_installed():
warnings.warn(
DeprecationWarning(
'Support for Flash Attention v1 is deprecated. Please upgrade to Flash Attention v2.4.2. To install Flash Attention v2.4.2, please run `pip install -e ".[gpu-flash2]"` from the root directory of the llm-foundry repository.'
VersionedDeprecationWarning(
'Support for Flash Attention v1 is deprecated. Please upgrade to Flash Attention v2.4.2. To install Flash Attention v2.4.2, please run `pip install -e ".[gpu-flash2]"` from the root directory of the llm-foundry repository.',
remove_version='0.6.0',
))

if self.attn_config[
Expand Down
27 changes: 27 additions & 0 deletions llmfoundry/utils/warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0


class VersionedDeprecationWarning(DeprecationWarning):
"""A custom deprecation warning class that includes version information.
Attributes:
message (str): The deprecation message describing why the feature is deprecated.
remove_version (str): The version in which the feature will be removed.
Example:
>>> def deprecated_function():
... warnings.warn(
... VersionedDeprecationWarning(
... "Function XYZ is deprecated.",
... after_version="2.0.0"
... )
... )
...
>>> deprecated_function()
DeprecationWarning: Function XYZ is deprecated. It will be removed in version 2.0.0.
"""

def __init__(self, message: str, remove_version: str) -> None:
super().__init__(message +
f' It will be removed in version {remove_version}.')
8 changes: 6 additions & 2 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from omegaconf import OmegaConf as om
from rich.traceback import install

from llmfoundry.utils.warnings import VersionedDeprecationWarning

install()

from transformers import PreTrainedTokenizerBase
Expand Down Expand Up @@ -219,8 +221,10 @@ def main(cfg: DictConfig) -> Trainer:
default_value=None)
if eval_gauntlet_config is not None:
warnings.warn(
'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`',
DeprecationWarning)
VersionedDeprecationWarning(
'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`.',
remove_version='0.5.0',
))
icl_subset_num_batches: Optional[int] = pop_config(cfg,
'icl_subset_num_batches',
must_exist=False,
Expand Down

0 comments on commit c1c4bbf

Please sign in to comment.