Skip to content

Commit

Permalink
👈 Add tokenizer arg back and add deprecation guidelines (#2348)
Browse files Browse the repository at this point in the history
* Add deprecation and backward compatibility guidelines

* Update tokenizer argument in trainer classes

* Add warning message for TRL Judges API
  • Loading branch information
qgallouedec committed Nov 11, 2024
1 parent 14ef1ab commit f662824
Show file tree
Hide file tree
Showing 16 changed files with 59 additions and 2 deletions.
27 changes: 27 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,30 @@ That's how `make test` is implemented (without the `pip install` line)!

You can specify a smaller set of tests to test only the feature
you're working on.
### Deprecation and Backward Compatibility
Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.
When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:
- **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
- **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
Example:
```python
warnings.warn(
"The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. "
"Please use the `Trainer.bar` method instead.",
FutureWarning,
)
```
The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:

- **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.

- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling removal around **5 minor releases** after the initial deprecation warning.

These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
6 changes: 6 additions & 0 deletions docs/source/judges.mdx
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Judges

<Tip warning={true}>

TRL Judges is an experimental API which is subject to change at any time.

</Tip>

TRL provides judges to easily compare two completions.

Make sure to have installed the required dependencies by running:
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput, has_length
from transformers.utils import is_peft_available
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import maybe_apply_chat_template
from ..models import PreTrainedModelWrapper, create_reference_model
Expand Down Expand Up @@ -317,6 +318,7 @@ class BCOTrainer(Trainer):

_tag_names = ["trl", "bco"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True)
def __init__(
self,
model: Union[PreTrainedModel, nn.Module, str] = None,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from .cpo_config import CPOConfig
Expand Down Expand Up @@ -103,6 +104,7 @@ class CPOTrainer(Trainer):

_tag_names = ["trl", "cpo"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ class DPOTrainer(Trainer):
],
custom_message="Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.",
)
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.16.0", raise_if_both_names=True)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_liger_kernel_available, is_peft_available
from transformers.utils.deprecation import deprecate_kwarg

from ..models import PreTrainedModelWrapper
from ..models.utils import unwrap_model_for_generation
Expand All @@ -61,6 +62,7 @@
class GKDTrainer(SFTTrainer):
_tag_names = ["trl", "gkd"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/iterative_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available
from transformers.utils.deprecation import deprecate_kwarg

from ..core import PPODecorators
from .utils import generate_model_card
Expand Down Expand Up @@ -80,6 +81,7 @@ class IterativeSFTTrainer(Trainer):

_tag_names = ["trl", "iterative-sft"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True)
def __init__(
self,
model: Optional[PreTrainedModel] = None,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
)
from transformers.trainer_utils import EvalLoopOutput, has_length
from transformers.utils import is_peft_available
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset
from ..models import PreTrainedModelWrapper, create_reference_model
Expand Down Expand Up @@ -312,6 +313,7 @@ class KTOTrainer(Trainer):

_tag_names = ["trl", "kto"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
def __init__(
self,
model: Union[PreTrainedModel, nn.Module, str] = None,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/nash_md_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import OptimizerNames
from transformers.utils import is_apex_available
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import is_conversational, maybe_apply_chat_template
from ..models.modeling_base import GeometricMixtureWrapper
Expand Down Expand Up @@ -93,6 +94,7 @@ class NashMDTrainer(OnlineDPOTrainer):

_tag_names = ["trl", "nash-md"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True)
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker
from transformers.training_args import OptimizerNames
from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..models import create_reference_model
Expand Down Expand Up @@ -125,6 +126,7 @@ class OnlineDPOTrainer(Trainer):

_tag_names = ["trl", "online-dpo"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
def __init__(
self,
model: Union[PreTrainedModel, nn.Module],
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from ..models import PreTrainedModelWrapper
Expand Down Expand Up @@ -114,6 +115,7 @@ class ORPOTrainer(Trainer):

_tag_names = ["trl", "orpo"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils.deprecation import deprecate_kwarg

from ..core import masked_mean, masked_whiten
from ..models.utils import unwrap_model_for_generation
Expand Down Expand Up @@ -90,6 +91,7 @@ def forward(self, **kwargs):
class PPOTrainer(Trainer):
_tag_names = ["trl", "ppo"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
def __init__(
self,
config: PPOConfig,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import maybe_apply_chat_template
from .reward_config import RewardConfig
Expand Down Expand Up @@ -80,6 +81,7 @@ def _tokenize(batch: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizerBase")
class RewardTrainer(Trainer):
_tag_names = ["trl", "reward-trainer"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils.deprecation import deprecate_kwarg

from ..models.utils import unwrap_model_for_generation
from ..trainer.utils import (
Expand Down Expand Up @@ -71,6 +72,7 @@
class RLOOTrainer(Trainer):
_tag_names = ["trl", "rloo"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
def __init__(
self,
config: RLOOConfig,
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class SFTTrainer(Trainer):
],
custom_message="Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.",
)
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.16.0", raise_if_both_names=True)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/xpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import OptimizerNames
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import is_conversational, maybe_apply_chat_template
from ..models.utils import unwrap_model_for_generation
Expand Down Expand Up @@ -92,6 +93,7 @@ class XPOTrainer(OnlineDPOTrainer):

_tag_names = ["trl", "xpo"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True)
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
Expand Down

0 comments on commit f662824

Please sign in to comment.