From f7e7e238a66baa59a6a2946394c1807da5078614 Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Wed, 31 Jul 2024 15:01:07 +0530 Subject: [PATCH 1/6] bug: On save event added to callback (#256) * feat: On save event added to callback Signed-off-by: Padmanabha V Seshadri * fix: Removed additional bracket Signed-off-by: Padmanabha V Seshadri * fix: Removed additional bracket Signed-off-by: Padmanabha V Seshadri * fix: Format issues resolved Signed-off-by: Padmanabha V Seshadri * fix: rebase with upstream and add new line Signed-off-by: Mehant Kammakomati --------- Signed-off-by: Padmanabha V Seshadri Signed-off-by: Mehant Kammakomati Co-authored-by: Mehant Kammakomati --- tests/data/trainercontroller/__init__.py | 1 + tests/data/trainercontroller/on-save.yaml | 10 ++++++++++ .../test_tuning_trainercontroller.py | 16 ++++++++++++++++ tuning/trainercontroller/callback.py | 14 ++++++++++++++ 4 files changed, 41 insertions(+) create mode 100644 tests/data/trainercontroller/on-save.yaml diff --git a/tests/data/trainercontroller/__init__.py b/tests/data/trainercontroller/__init__.py index aaaeabe9..35f4f13c 100644 --- a/tests/data/trainercontroller/__init__.py +++ b/tests/data/trainercontroller/__init__.py @@ -77,3 +77,4 @@ TRAINER_CONFIG_TEST_THRESHOLDED_TRAINING_LOSS_YAML = os.path.join( _DATA_DIR, "thresholded-training-loss.yaml" ) +TRAINER_CONFIG_TEST_ON_SAVE_YAML = os.path.join(_DATA_DIR, "on-save.yaml") diff --git a/tests/data/trainercontroller/on-save.yaml b/tests/data/trainercontroller/on-save.yaml new file mode 100644 index 00000000..a6fffb11 --- /dev/null +++ b/tests/data/trainercontroller/on-save.yaml @@ -0,0 +1,10 @@ +controller_metrics: + - name: state + class: TrainingState +controllers: + - name: stop_on_training_loss_on_save + triggers: + - on_save + rule: state["epoch"] >= 0.5 + operations: + - hfcontrols.should_training_stop diff --git a/tests/trainercontroller/test_tuning_trainercontroller.py 
b/tests/trainercontroller/test_tuning_trainercontroller.py index 7f98ace9..c4464da8 100644 --- a/tests/trainercontroller/test_tuning_trainercontroller.py +++ b/tests/trainercontroller/test_tuning_trainercontroller.py @@ -138,6 +138,22 @@ def test_thresholded_training_loss(): assert control.should_training_stop is True +def test_thresholded_training_loss_on_save(): + """Tests the thresholded training loss example in + `examples/trainer-controller-configs/on-save.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback(td.TRAINER_CONFIG_TEST_ON_SAVE_YAML) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) + # Trigger rule and test the condition + tc_callback.on_save(args=test_data.args, state=test_data.states[2], control=control) + assert control.should_training_stop is True + + def test_non_decreasing_training_loss(): """Tests the non-decreasing training loss example in `examples/trainer-controller-configs/non-decreasing-training-loss.yaml` diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py index b7cd005b..0ca83305 100644 --- a/tuning/trainercontroller/callback.py +++ b/tuning/trainercontroller/callback.py @@ -548,3 +548,17 @@ def on_evaluate( kwargs["state"] = state kwargs["control"] = control self._actions_on_event(event_name="on_evaluate", **kwargs) + + def on_save( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_save", **kwargs) From 71c3e8a85a5684c81a8f72dbaccb8cfdfec53338 Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Wed, 31 Jul 2024 15:27:09 
+0530 Subject: [PATCH 2/6] feat: All metric handling changes (#263) * feat: All metric handling changes Signed-off-by: Padmanabha V Seshadri * fix: Format issues Signed-off-by: Padmanabha V Seshadri --------- Signed-off-by: Padmanabha V Seshadri --- examples/trainercontroller_configs/loss.yaml | 4 ++-- tests/data/trainercontroller/loss_custom_operation.yaml | 4 ++-- .../loss_custom_operation_invalid_action.yaml | 4 ++-- tests/data/trainercontroller/loss_invalid_metric.yaml | 4 ++-- tests/data/trainercontroller/loss_invalid_operation.yaml | 4 ++-- .../trainercontroller/loss_invalid_operation_action.yaml | 4 ++-- tests/data/trainercontroller/loss_invalid_trigger.yaml | 4 ++-- tests/data/trainercontroller/loss_on_threshold.yaml | 4 ++-- .../loss_on_threshold_with_trainer_state.yaml | 4 ++-- tests/data/trainercontroller/loss_unavailable_metric.yaml | 4 ++-- .../trainercontroller/loss_with_malicious_input_rule.yaml | 2 +- .../trainercontroller/loss_with_malicious_os_rule.yaml | 2 +- tests/trainercontroller/custom_operation.py | 7 ------- tests/trainercontroller/custom_operation_invalid_action.py | 7 ------- tuning/sft_trainer.py | 2 +- tuning/trainercontroller/callback.py | 1 + tuning/trainercontroller/controllermetrics/loss.py | 2 +- tuning/trainercontroller/operations/hfcontrols.py | 2 +- tuning/trainercontroller/operations/operation.py | 4 +++- 19 files changed, 29 insertions(+), 40 deletions(-) diff --git a/examples/trainercontroller_configs/loss.yaml b/examples/trainercontroller_configs/loss.yaml index d7d0baa2..c4322a6b 100644 --- a/examples/trainercontroller_configs/loss.yaml +++ b/examples/trainercontroller_configs/loss.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller triggers: - on_log - rule: loss < 1.0 + rule: training_loss["loss"] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_custom_operation.yaml 
b/tests/data/trainercontroller/loss_custom_operation.yaml index 60345923..3ec952a8 100644 --- a/tests/data/trainercontroller/loss_custom_operation.yaml +++ b/tests/data/trainercontroller/loss_custom_operation.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss operations: - name: custom_operation @@ -8,6 +8,6 @@ controllers: - name: loss_controller_custom_operation triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - custom_operation.should_perform_action_xyz \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml b/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml index 3dac47cb..e0d3a71d 100644 --- a/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml +++ b/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss operations: - name: custom_operation @@ -8,6 +8,6 @@ controllers: - name: loss_controller_custom_operation_invalid_action triggers: - on_log - rule: loss < 1.0 + rule: training_loss["loss"] < 1.0 operations: - custom_operation.should_ \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_metric.yaml b/tests/data/trainercontroller/loss_invalid_metric.yaml index 4d94878a..8491175b 100644 --- a/tests/data/trainercontroller/loss_invalid_metric.yaml +++ b/tests/data/trainercontroller/loss_invalid_metric.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: MissingMetricClass controllers: - name: loss_controller_invalid_metric triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_operation.yaml b/tests/data/trainercontroller/loss_invalid_operation.yaml index f904e27d..769c9441 100644 
--- a/tests/data/trainercontroller/loss_invalid_operation.yaml +++ b/tests/data/trainercontroller/loss_invalid_operation.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_invalid_operation triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - missingop.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_operation_action.yaml b/tests/data/trainercontroller/loss_invalid_operation_action.yaml index 3015516e..7d8a17ad 100644 --- a/tests/data/trainercontroller/loss_invalid_operation_action.yaml +++ b/tests/data/trainercontroller/loss_invalid_operation_action.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_invalid_operation_action triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.missingaction \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_trigger.yaml b/tests/data/trainercontroller/loss_invalid_trigger.yaml index 382ad778..38abe7ed 100644 --- a/tests/data/trainercontroller/loss_invalid_trigger.yaml +++ b/tests/data/trainercontroller/loss_invalid_trigger.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_invalid_trigger triggers: - log_it_all_incorrect_trigger_name - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_on_threshold.yaml b/tests/data/trainercontroller/loss_on_threshold.yaml index d7d0baa2..24891e8e 100644 --- a/tests/data/trainercontroller/loss_on_threshold.yaml +++ b/tests/data/trainercontroller/loss_on_threshold.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: 
loss_controller triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml b/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml index 45e2a3ee..9dc764c4 100644 --- a/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml +++ b/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml @@ -1,12 +1,12 @@ controller_metrics: - name: state class: TrainingState - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller triggers: - on_log - rule: loss < 2 and state["epoch"] >= 0.5 + rule: training_loss['loss'] < 2 and state["epoch"] >= 0.5 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_unavailable_metric.yaml b/tests/data/trainercontroller/loss_unavailable_metric.yaml index 055b93cf..56418429 100644 --- a/tests/data/trainercontroller/loss_unavailable_metric.yaml +++ b/tests/data/trainercontroller/loss_unavailable_metric.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_unavailable_metric triggers: - on_step_end - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml b/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml index 6d5c6532..e2cbb26d 100644 --- a/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml +++ b/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_wrong_input_rule diff --git a/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml 
b/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml index badcf940..5ee4bc22 100644 --- a/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml +++ b/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_wrong_os_rule diff --git a/tests/trainercontroller/custom_operation.py b/tests/trainercontroller/custom_operation.py index 2c402fa9..522200b4 100644 --- a/tests/trainercontroller/custom_operation.py +++ b/tests/trainercontroller/custom_operation.py @@ -26,13 +26,6 @@ class CustomOperation(Operation): """Implements a custom operation for testing""" - def __init__(self, **_): - """Initializes the custom operation class. - Args: - kwargs: List of arguments (key, value)-pairs - """ - super().__init__() - def should_perform_action_xyz(self, control: TrainerControl, **_): """This method performs a set training stop flag action. diff --git a/tests/trainercontroller/custom_operation_invalid_action.py b/tests/trainercontroller/custom_operation_invalid_action.py index 5c04199d..6871a64f 100644 --- a/tests/trainercontroller/custom_operation_invalid_action.py +++ b/tests/trainercontroller/custom_operation_invalid_action.py @@ -26,13 +26,6 @@ class CustomOperationInvalidAction(Operation): """Implements a custom operation for testing""" - def __init__(self, **_): - """Initializes the custom operation class. - Args: - kwargs: List of arguments (key, value)-pairs - """ - super().__init__() - def should_(self, control: TrainerControl, **_): """This method defines an action within an invalid name. 
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 0e360ad4..30095b98 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -160,7 +160,7 @@ def train( trainer_controller_args.trainer_controller_config_file is not None ): tc_callback = TrainerControllerCallback( - trainer_controller_args.trainer_controller_config_file + trainer_controller_args.trainer_controller_config_file, ) trainer_callbacks.append(tc_callback) diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py index 0ca83305..ebb661b3 100644 --- a/tuning/trainercontroller/callback.py +++ b/tuning/trainercontroller/callback.py @@ -258,6 +258,7 @@ def _take_control_actions(self, event_name: str, **kwargs): operation_action.instance.act( action=operation_action.action, event_name=event_name, + tc_metrics=self.metrics, **kwargs, ) diff --git a/tuning/trainercontroller/controllermetrics/loss.py b/tuning/trainercontroller/controllermetrics/loss.py index 2fd45014..543d6395 100644 --- a/tuning/trainercontroller/controllermetrics/loss.py +++ b/tuning/trainercontroller/controllermetrics/loss.py @@ -61,4 +61,4 @@ def compute(self, state: TrainerState = None, **kwargs) -> Any: log = state.log_history[i] if "loss" not in log: continue - return float(log["loss"]) + return log diff --git a/tuning/trainercontroller/operations/hfcontrols.py b/tuning/trainercontroller/operations/hfcontrols.py index 2bba9a1d..c1f7589e 100644 --- a/tuning/trainercontroller/operations/hfcontrols.py +++ b/tuning/trainercontroller/operations/hfcontrols.py @@ -29,7 +29,7 @@ def __init__(self, **kwargs): for control_field in fields(TrainerControl): if re.search(r"^should_.+", control_field.name) is not None: setattr(self, control_field.name, self.control_action) - super().__init__() + super().__init__(**kwargs) def control_action(self, control: TrainerControl, **kwargs): """This method peeks into the stack-frame of the caller to get the action the triggered diff --git 
a/tuning/trainercontroller/operations/operation.py b/tuning/trainercontroller/operations/operation.py index 916420e8..baa220c1 100644 --- a/tuning/trainercontroller/operations/operation.py +++ b/tuning/trainercontroller/operations/operation.py @@ -7,12 +7,14 @@ class Operation(metaclass=abc.ABCMeta): """Base class for operations""" - def __init__(self): + def __init__(self, name: str, **kwargs): """Initializes the HuggingFace controls. In this init, we follow the convention that every action should preceed with prefix `should_`. If so, it is treated as a valid action. """ self.valid_actions = {} + self.name = name + self.kwargs = kwargs for action_name, action_method in inspect.getmembers( self, predicate=inspect.ismethod ): From 72471ae8737d94a18d923dadbfb4b2f98d417665 Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Wed, 31 Jul 2024 18:08:21 +0530 Subject: [PATCH 3/6] feat: Configuration to set logging level for trigger log (#241) * feat: Added the triggered login in the operation Signed-off-by: Padmanabha V Seshadri * fix: Formatting issues Signed-off-by: Padmanabha V Seshadri * fix: Added default config Signed-off-by: Padmanabha V Seshadri * fix: Moved the variable to right scope Signed-off-by: Padmanabha V Seshadri * fix: Checked added to validate config log level Signed-off-by: Padmanabha V Seshadri * fix: Removed some unwanted log file Signed-off-by: Padmanabha V Seshadri --------- Signed-off-by: Padmanabha V Seshadri --- ...ining-loss-below-threshold-log-config.yaml | 14 +++++++++ tuning/trainercontroller/callback.py | 31 +++++++++++++++---- .../trainercontroller/operations/operation.py | 19 +++++++++++- 3 files changed, 57 insertions(+), 7 deletions(-) create mode 100644 examples/trainercontroller_configs/epoch-level-training-loss-below-threshold-log-config.yaml diff --git a/examples/trainercontroller_configs/epoch-level-training-loss-below-threshold-log-config.yaml 
b/examples/trainercontroller_configs/epoch-level-training-loss-below-threshold-log-config.yaml new file mode 100644 index 00000000..53c8a177 --- /dev/null +++ b/examples/trainercontroller_configs/epoch-level-training-loss-below-threshold-log-config.yaml @@ -0,0 +1,14 @@ +controller_metrics: + - name: training_loss_window + class: HistoryBasedMetric + arguments: + window_size: 1 +controllers: + - name: epoch_level_stop_on_training_loss_below_threshold + triggers: + - on_step_end + rule: len(training_loss_window["training_loss"]["loss"]) == training_loss_window["window_size"] and training_loss_window["training_loss"]["loss"][0] < 2.2 and training_loss_window["training_loss"]["epoch"][0] > 2 + config: + trigger_log_level: warning + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py index ebb661b3..2c5ddfcd 100644 --- a/tuning/trainercontroller/callback.py +++ b/tuning/trainercontroller/callback.py @@ -59,11 +59,14 @@ CONTROLLER_CONFIG_KEY = "config" CONTROLLER_PATIENCE_CONFIG_KEY = "patience" CONTROLLER_TRIGGERS_KEY = "triggers" +CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL = "trigger_log_level" CONTROLLER_OPERATIONS_KEY = OPERATIONS_KEY -# Default operations / metrics to register +# Default values DEFAULT_OPERATIONS = {"operations": [{"name": "hfcontrols", "class": "HFControls"}]} DEFAULT_METRICS = {} +DEFAULT_CONFIG = {} +DEFAULT_TRIGGER_LOG_LEVEL = "debug" # pylint: disable=too-many-instance-attributes class TrainerControllerCallback(TrainerCallback): @@ -250,15 +253,14 @@ def _take_control_actions(self, event_name: str, **kwargs): continue if rule_succeeded: for operation_action in control_action.operation_actions: - logger.info( - "Taking [%s] action in controller [%s]", - operation_action.action, - control_action.name, - ) operation_action.instance.act( action=operation_action.action, event_name=event_name, tc_metrics=self.metrics, + 
control_name=control_action.name, + log_level=control_action.config[ + CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL + ], **kwargs, ) @@ -303,6 +305,7 @@ def on_init_end( kwargs["state"] = state kwargs["control"] = control + log_levels = logging.get_log_levels_dict() # Check if there any metrics listed in the configuration if ( CONTROLLER_METRICS_KEY not in self.trainer_controller_config @@ -399,8 +402,24 @@ def on_init_end( ), operation_actions=[], ) + config_log_level_str = DEFAULT_TRIGGER_LOG_LEVEL if CONTROLLER_CONFIG_KEY in controller: control.config = controller[CONTROLLER_CONFIG_KEY] + config_log_level_str = control.config.get( + CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL, config_log_level_str + ) + if config_log_level_str not in log_levels: + logger.warning( + "Incorrect trigger log-level [%s] specified in the config." + " Defaulting to 'debug' level", + config_log_level_str, + ) + config_log_level_str = DEFAULT_TRIGGER_LOG_LEVEL + else: + control.config = DEFAULT_CONFIG + control.config[CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL] = log_levels[ + config_log_level_str + ] if CONTROLLER_PATIENCE_CONFIG_KEY in controller: control.patience = PatienceControl( **controller[CONTROLLER_PATIENCE_CONFIG_KEY] diff --git a/tuning/trainercontroller/operations/operation.py b/tuning/trainercontroller/operations/operation.py index baa220c1..6e6d764f 100644 --- a/tuning/trainercontroller/operations/operation.py +++ b/tuning/trainercontroller/operations/operation.py @@ -3,6 +3,11 @@ import inspect import re +# Third Party +from transformers.utils import logging + +logger = logging.get_logger(__name__) + class Operation(metaclass=abc.ABCMeta): """Base class for operations""" @@ -32,15 +37,27 @@ def validate(self, action: str) -> bool: """ return action in self.valid_actions - def act(self, action: str, **kwargs): + def act( + self, action: str, event_name: str, control_name: str, log_level: int, **kwargs + ): """Validates the action and invokes it. Args: action: str. String depicting the action. 
+ event_name: str. Event name triggering the act. + control_name: str. Name of the controller defining the act. + log_level: int. Log level for triggering the log. kwargs: List of arguments (key, value)-pairs. """ if not self.validate(action): raise ValueError(f"Invalid operation {action}") + logger.log( + log_level, + "Taking [%s] action in controller [%s] triggered at event [%s]", + action, + control_name, + event_name, + ) self.valid_actions[action](**kwargs) def get_actions(self) -> list[str]: From f57ff63650ba139d6e0471d244df4a70e4b13d0b Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Wed, 31 Jul 2024 17:08:52 -0600 Subject: [PATCH 4/6] limit peft deps until investigate (#274) Signed-off-by: Anh-Uong --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3438ecfb..e4bce3e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "tokenizers>=0.13.3,<1.0", "tqdm>=4.66.2,<5.0", "trl>=0.9.3,<1.0", -"peft>=0.8.0,<0.13", +"peft>=0.8.0,<0.12", "datasets>=2.15.0,<3.0", "fire>=0.5.0,<1.0", "simpleeval>=0.9.13,<1.0", From 3439a681d9257f429e0bed12cef37504cd2a316c Mon Sep 17 00:00:00 2001 From: Sukriti Sharma Date: Wed, 31 Jul 2024 17:23:18 -0600 Subject: [PATCH 5/6] Data custom collator (#260) * refactor code to preprocess datasets Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * fix formatting Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * allow input/output in validate args Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * format input/output JSON and mask Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * function to return suitable collator Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * add tests for SFT Trainer input/output format Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * remove unused functions Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * add eos token to input/output format 
Signed-off-by: Sukriti-Sharma4 * fix tests Signed-off-by: Sukriti-Sharma4 * improve docstrings Signed-off-by: Sukriti-Sharma4 * keeping JSON keys constant Signed-off-by: Sukriti-Sharma4 * support for input/output format Signed-off-by: Sukriti-Sharma4 * formatting fixes Signed-off-by: Sukriti-Sharma4 * update rEADME formats Signed-off-by: Sukriti-Sharma4 * formatting README Signed-off-by: Sukriti-Sharma4 --------- Signed-off-by: Sukriti-Sharma4 Co-authored-by: Alex-Brooks --- README.md | 22 ++- tests/test_sft_trainer.py | 56 ++++++ tests/utils/test_preprocessing_utils.py | 166 ++++++++++-------- tuning/sft_trainer.py | 10 +- tuning/utils/preprocessing_utils.py | 216 ++++++++---------------- 5 files changed, 243 insertions(+), 227 deletions(-) diff --git a/README.md b/README.md index ac65f1c2..a37abc1c 100644 --- a/README.md +++ b/README.md @@ -50,9 +50,11 @@ pip install fms-hf-tuning[fms-accel] `fms-acceleration` is a collection of plugins that packages that accelerate fine-tuning / training of large models, as part of the `fms-hf-tuning` suite. For more details on see [this section below](#fms-acceleration). ## Data format -We support two data formats: +We support the following data formats: -1. #### Pre-process the JSON/JSONL dataset +### 1. JSON formats with a single sequence and a specified response_template to use for masking on completion. + +#### 1.1 Pre-process the JSON/JSONL dataset Pre-process the JSON/JSONL dataset to contain a single sequence of each data instance containing input + Response. The trainer is configured to expect a response template as a string. For example, if one wants to prepare the `alpaca` format data to feed into this trainer, it is quite easy and can be done with the following code. ```python @@ -87,7 +89,7 @@ The same way can be applied to any dataset, with more info can be found [here](h Once the JSON is converted using the formatting function, pass the `dataset_text_field` containing the single sequence to the trainer. -2. 
#### Format JSON/JSONL on the fly Pass a JSON/JSONL and a `data_formatter_template` to use the formatting function on the fly while tuning. The template should specify fields of JSON with `{{field}}`. While tuning, the data will be converted to a single sequence using the template. JSON fields can contain alpha-numeric characters, spaces and the following special symbols - "." , "_", "-". @@ -101,8 +103,20 @@ data_formatter_template: `### Input: {{input}} \n\n##Label: {{output}}` Formatting will happen on the fly while tuning. The keys in template should match fields in JSON file. The `response template` corresponding to the above template will need to be supplied. in this case, `response template` = `\n## Label:`. +##### In conclusion, if using the response_template and single sequence, either the `data_formatter_template` argument or `dataset_text_field` needs to be supplied to the trainer. + +### 2. JSONL with input and output fields (no response template) + + Pass a JSONL containing fields "input" with source text and "output" with class labels. Pre-format the input as you see fit. The output field will simply be concatenated to the end of input to create single sequence, and input will be masked. + + The "input" and "output" field names are mandatory and cannot be changed. -##### In conclusion, either the `data_formatter_template` argument or `dataset_text_field` needs to be supplied to the trainer. 
+Example: Train.jsonl + +``` +{"input": "### Input: Colorado is a state in USA ### Output:", "output": "USA : Location"} +{"input": "### Input: Arizona is also a state in USA ### Output:", "output": "USA : Location"} +``` ## Supported Models diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 7c96ccce..b01c216c 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -35,6 +35,7 @@ MALFORMATTED_DATA, MODEL_NAME, TWITTER_COMPLAINTS_DATA, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT, TWITTER_COMPLAINTS_JSON_FORMAT, ) @@ -724,3 +725,58 @@ def test_run_with_good_experimental_metadata(): additional_callbacks=[TrainerCallback()], exp_metadata=metadata, ) + + +### Tests for pretokenized data +def test_pretokenized_dataset(): + """Ensure that we can provide a pretokenized dataset with input/output format.""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + data_args = copy.deepcopy(DATA_ARGS) + data_args.dataset_text_field = None + data_args.response_template = None + data_args.training_data_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT + sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) + _validate_training(tempdir) + + +@pytest.mark.parametrize( + "dataset_text_field,response_template", + [ + ("foo", None), + (None, "bar"), + ], +) +def test_pretokenized_dataset_bad_args(dataset_text_field, response_template): + """Ensure that we can't provide only dataset text field / response template for pretok data.""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + data_args = copy.deepcopy(DATA_ARGS) + data_args.dataset_text_field = dataset_text_field + data_args.response_template = response_template + data_args.training_data_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT + # We should raise an error since we should not have a dataset text + # field or a response template if we have 
pretokenized data + with pytest.raises(ValueError): + sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) + + +def test_pretokenized_dataset_wrong_format(): + """Ensure that we fail to generate data if the data is in the wrong format.""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + data_args = copy.deepcopy(DATA_ARGS) + data_args.dataset_text_field = None + data_args.response_template = None + data_args.training_data_path = TWITTER_COMPLAINTS_DATA + + # It would be best to handle this in a way that is more understandable; we might + # need to directly add validation prior to the dataset generation since datasets + # is essentially swallowing a KeyError here. + with pytest.raises(ValueError): + sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) diff --git a/tests/utils/test_preprocessing_utils.py b/tests/utils/test_preprocessing_utils.py index e24cf710..9d0e1519 100644 --- a/tests/utils/test_preprocessing_utils.py +++ b/tests/utils/test_preprocessing_utils.py @@ -18,7 +18,7 @@ from tuning.utils.preprocessing_utils import ( combine_sequence, format_dataset, - get_data_trainer_kwargs, + get_data_collator, get_formatted_dataset_with_single_sequence, get_preprocessed_dataset, load_hf_dataset_from_jsonl_file, @@ -42,6 +42,24 @@ def test_combine_sequence(input_element, output_element, expected_res): assert comb_seq == expected_res +@pytest.mark.parametrize( + "input_element,output_element,expected_res", + [ + ("foo ", "bar", "foo bar"), + ("foo\n", "bar", "foo\nbar"), + ("foo\t", "bar", "foo\tbar"), + ("foo", "bar", "foo bar"), + ], +) +def test_combine_sequence_adds_eos(input_element, output_element, expected_res): + """Ensure that input / output elements are combined with correct whitespace handling.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token) + expected_res += 
tokenizer.eos_token + assert isinstance(comb_seq, str) + assert comb_seq == expected_res + + # Tests for loading the dataset from disk def test_load_hf_dataset_from_jsonl_file(): input_field_name = "Tweet text" @@ -108,80 +126,53 @@ def test_get_preprocessed_dataset(max_sequence_length): assert key_lengths.pop() <= max_sequence_length -# Tests for fetching train args @pytest.mark.parametrize( - "use_validation_data, collator_type, packing", + "packing, response_template, formatted_train_dataset, max_seq_length, expected_collator", [ - (True, None, True), - (False, None, True), - (True, DataCollatorForCompletionOnlyLM, False), - (False, DataCollatorForCompletionOnlyLM, False), + ( + False, + "\n### Label:", + load_hf_dataset_from_jsonl_file( + TWITTER_COMPLAINTS_DATA, + input_field_name="Tweet text", + output_field_name="text_label", + ), + 1024, + DataCollatorForCompletionOnlyLM, + ), + ( + False, + None, + Dataset.from_list( + [ + { + "input_ids": [9437, 29, 210], + "attention_mask": [1, 1, 1], + "labels": [1, 20, 30], + } + ] + ), + 1024, + DataCollatorForSeq2Seq, + ), ], ) -def test_get_trainer_kwargs_with_response_template_and_text_field( - use_validation_data, collator_type, packing +def test_get_data_collator( + packing, + response_template, + formatted_train_dataset, + max_seq_length, + expected_collator, ): - training_data_path = TWITTER_COMPLAINTS_DATA - validation_data_path = training_data_path if use_validation_data else None - # Expected columns in the raw loaded dataset for the twitter data - column_names = set(["Tweet text", "ID", "Label", "text_label", "output"]) - trainer_kwargs = get_data_trainer_kwargs( - training_data_path=training_data_path, - validation_data_path=validation_data_path, - packing=packing, - response_template="\n### Label:", - max_sequence_length=100, - tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME), - dataset_text_field="output", + """Ensure that the correct collator type is fetched based on the data args""" + collator = 
get_data_collator( + packing, + response_template, + AutoTokenizer.from_pretrained(MODEL_NAME), + formatted_train_dataset, + max_seq_length, ) - assert len(trainer_kwargs) == 3 - # If we are packing, we should not have a data collator - if collator_type is None: - assert trainer_kwargs["data_collator"] is None - else: - assert isinstance(trainer_kwargs["data_collator"], collator_type) - - # We should only have a validation dataset if one is present - if validation_data_path is None: - assert trainer_kwargs["eval_dataset"] is None - else: - assert isinstance(trainer_kwargs["eval_dataset"], Dataset) - assert set(trainer_kwargs["eval_dataset"].column_names) == column_names - - assert isinstance(trainer_kwargs["train_dataset"], Dataset) - assert set(trainer_kwargs["train_dataset"].column_names) == column_names - - -@pytest.mark.parametrize("use_validation_data", [True, False]) -def test_get_trainer_kwargs_with_custom_masking(use_validation_data): - training_data_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT - validation_data_path = training_data_path if use_validation_data else None - # Expected columns in the raw loaded dataset for the twitter data - column_names = set(["input_ids", "attention_mask", "labels"]) - trainer_kwargs = get_data_trainer_kwargs( - training_data_path=training_data_path, - validation_data_path=validation_data_path, - packing=False, - response_template=None, - max_sequence_length=100, - tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME), - dataset_text_field=None, - ) - assert len(trainer_kwargs) == 4 - # If we are packing, we should not have a data collator - assert isinstance(trainer_kwargs["data_collator"], DataCollatorForSeq2Seq) - - # We should only have a validation dataset if one is present - if validation_data_path is None: - assert trainer_kwargs["eval_dataset"] is None - else: - assert isinstance(trainer_kwargs["eval_dataset"], Dataset) - assert set(trainer_kwargs["eval_dataset"].column_names) == column_names - - assert 
isinstance(trainer_kwargs["train_dataset"], Dataset) - assert set(trainer_kwargs["train_dataset"].column_names) == column_names - # Needed to sidestep TRL validation - assert trainer_kwargs["formatting_func"] is not None + assert isinstance(collator, expected_collator) # Tests for validating data args @@ -197,6 +188,14 @@ def test_get_trainer_kwargs_with_custom_masking(use_validation_data): ), False, ), + # data formatter with no response template + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_DATA, + data_formatter_template="### Input: {{input}} \n\n### Response: {{output}}", + ), + False, + ), # response template with no dataset_text_field or formatter ( configs.DataArguments( @@ -205,9 +204,17 @@ def test_get_trainer_kwargs_with_custom_masking(use_validation_data): ), False, ), + # JSON without input / output for no single sequence arguments + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_DATA, + ), + False, + ), ], ) def test_validate_args(data_args, packing): + """Ensure that respective errors are thrown for incorrect data arguments""" with pytest.raises(ValueError): validate_data_args(data_args, packing) @@ -255,12 +262,27 @@ def test_get_formatted_dataset_with_single_sequence( data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}", ) ), + # input/output JSON with masking on input + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT, + validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT, + ) + ), ], ) def test_format_dataset(data_args): + """Ensure that the train/eval data are properly formatted based on the data args / text field""" tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - train_set, eval_set, dataset_text_field = format_dataset(data_args, tokenizer) + train_set, eval_set, dataset_text_field = format_dataset( + data_args, tokenizer, max_seq_length=1024 + ) assert isinstance(train_set, Dataset) assert isinstance(eval_set, Dataset) - assert 
dataset_text_field in train_set.column_names - assert dataset_text_field in eval_set.column_names + if dataset_text_field is None: + column_names = set(["input_ids", "attention_mask", "labels"]) + assert set(eval_set.column_names) == column_names + assert set(train_set.column_names) == column_names + else: + assert dataset_text_field in train_set.column_names + assert dataset_text_field in eval_set.column_names diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 30095b98..d889c67e 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -268,8 +268,14 @@ def train( formatted_train_dataset, formatted_validation_dataset, dataset_text_field, - ) = format_dataset(data_args, tokenizer) - data_collator = get_data_collator(packing, data_args.response_template, tokenizer) + ) = format_dataset(data_args, tokenizer, max_seq_length) + data_collator = get_data_collator( + packing, + data_args.response_template, + tokenizer, + formatted_train_dataset, + max_seq_length, + ) if framework is not None and framework.requires_agumentation: model, (peft_config,) = framework.augmentation( diff --git a/tuning/utils/preprocessing_utils.py b/tuning/utils/preprocessing_utils.py index 88db911a..e3dbdb37 100644 --- a/tuning/utils/preprocessing_utils.py +++ b/tuning/utils/preprocessing_utils.py @@ -28,6 +28,10 @@ logger = logging.get_logger("sft_trainer_preprocessing") +# In future we may make the fields configurable +JSON_INPUT_KEY = "input" +JSON_OUTPUT_KEY = "output" + def validate_data_args(data_args: configs.DataArguments, packing: bool): @@ -36,20 +40,14 @@ def validate_data_args(data_args: configs.DataArguments, packing: bool): ), "Training data path has to be set and str" # Dataset containing single sequence needs a response template for masking - if data_args.response_template is None and data_args.dataset_text_field is not None: - if packing is False: - raise ValueError( - "Since dataset_text_field is provided and packing is disabled, \ - needs a corresponding 
response template for masking" - ) - - # Currently if packing is false, we require a response_template. This may change in future. - if packing is False: + if data_args.dataset_text_field or data_args.data_formatter_template: if data_args.response_template is None: - raise ValueError( - "Response template is None, needs to be set for training \ - with packing disabled." - ) + if packing is False: + raise ValueError( + "Since dataset_text_field or data_formatter_template \ + is provided and packing is disabled, \ + needs a corresponding response template for masking" + ) if data_args.response_template: # To use Response template, pass datasets with single sequence instances \ @@ -65,16 +63,32 @@ def validate_data_args(data_args: configs.DataArguments, packing: bool): "dataset_text_field and data_formatter_template are both set,\ but are mutually exclusive options" ) - # TODO(s) In future seupport two more formats: - # 1. Allow no response template, and JSON with input/output fields and mask input - # 2. Allow pretokenized Dataset besides JSON. + # If not single sequence, JSON should contain input/output fields + if not (data_args.dataset_text_field or data_args.data_formatter_template): + json_dataset = datasets.load_dataset( + "json", data_files=data_args.training_data_path + ) + if JSON_INPUT_KEY not in json_dataset["train"].column_names: + raise ValueError( + "JSON should contain input field if no dataset_text_field or \ + data_formatter_template specified" + ) + if JSON_OUTPUT_KEY not in json_dataset["train"].column_names: + raise ValueError( + "JSON should contain output field if no dataset_text_field or \ + data_formatter_template specified" + ) + # TODO(s) In future support + # Allow pretokenized Dataset besides JSON. 
def get_data_collator( packing: bool, response_template: Optional[str], tokenizer: AutoTokenizer, + formatted_train_dataset: Dataset, + max_seq_length: int, ) -> Callable: """Create and return the the appropriate collator type based on the configuration for packing, response_template, and dataset_text_field. @@ -86,6 +100,10 @@ def get_data_collator( Response template to be used for formatting by TRL. tokenizer: AutoTokenizer Loaded tokenizer object to be used by the collator. + formatted_train_dataset: Dataset + Train Dataset formatted for tuning + max_seq_length: int + Max sequence length expected Returns: Callable @@ -105,25 +123,29 @@ def get_data_collator( tokenizer=tokenizer, ignore_index=configs.IGNORE_INDEX, ) - # TO DO with future changes, - # 1. Support no packing and seq2seq colator without response template - # # if dataset_text_field is None and response_template is None: - # # Use the seq2seq data collator; - # # Note that this automatically pads labels with -100 - # return DataCollatorForSeq2Seq( - # tokenizer=tokenizer, padding=True, max_length=max_sequence_length - # ) - # 2. add anything needed for preprocessed input + # Note that this automatically pads labels with -100 + # TODO check if this is sufficient for preprocessed + if ( + "attention_mask" in formatted_train_dataset.column_names + and "labels" in formatted_train_dataset.column_names + ): + return DataCollatorForSeq2Seq( + tokenizer=tokenizer, padding=True, max_length=max_seq_length + ) raise ValueError( "Could not pick a data collator. 
Please refer to supported data formats" ) -def format_dataset(data_args: configs.DataArguments, tokenizer: AutoTokenizer): +def format_dataset( + data_args: configs.DataArguments, tokenizer: AutoTokenizer, max_seq_length: int +): """ Args: data_args: tuning.config.configs.DataArguments tokenizer: AutoTokenizer + max_seq_length: int + Max sequence length expected Returns: Tuple(Dataset, Dataset, str) tuple containing train_dataset, eval_dataset and dataset_text_field @@ -148,132 +170,25 @@ def format_dataset(data_args: configs.DataArguments, tokenizer: AutoTokenizer): data_args.data_formatter_template, ) logger.info("Validation dataset length is %s", len(eval_dataset)) - # TODO: add a else here for preprocessing - return train_dataset, eval_dataset, dataset_text_field - - -################################################################################### -### The functions below are not yet used. Iterative development towards new features - - -def get_data_collator_temp( - packing: bool, - dataset_text_field: Optional[str], - response_template: Optional[str], - max_sequence_length: int, - tokenizer: AutoTokenizer, -) -> Callable: - """Create and return the the appropriate collator type based on the configuration for packing, - response_template, and dataset_text_field. - - Args: - packing: bool - Whether or not we should apply packing or not. - dataset_text_field: Optional[str] - Dataset text field fto be used for formatting by TRL. - response_template: Optional[str] - Response template to be used for formatting by TRL. - max_sequence_length: int - Max sequence length to be used for sequence tokenization. - tokenizer: AutoTokenizer - Loaded tokenizer object to be used by the collator. - - Returns: - Callable - Callable collator to be leveraged by the trainer. 
- """ - if not packing: - if dataset_text_field is None and response_template is None: - # Use the seq2seq data collator; note that this automatically pads labels with -100 - return DataCollatorForSeq2Seq( - tokenizer=tokenizer, padding=True, max_length=max_sequence_length - ) - # TODO: near term - how response template ids are parsed out needs to be cleaned. - # The [2:] here applies if response template has \n prefix, it is needed to strip \n, - # otherwise template is not found. We will create issue to clean this out after we discuss - # data formats and collators we will support. - response_template_ids = tokenizer.encode( - response_template, add_special_tokens=False - )[2:] - return DataCollatorForCompletionOnlyLM( - response_template=response_template_ids, - tokenizer=tokenizer, - ignore_index=configs.IGNORE_INDEX, - ) - - -def get_data_trainer_kwargs( - training_data_path: str, - validation_data_path: str, - packing: bool, - response_template: Optional[str], - max_sequence_length: int, - tokenizer: AutoTokenizer, - dataset_text_field: Optional[str], -) -> Dict[str, Any]: - """Get trainer args related to data / processing. At the moment, this consists of: - - the training dataset - - the evaluation dataset - - the data collator - - Maybe a formatting a function [only for a special case for validation] - The result can be kwarg expanded into the trainer initialization. - - Args: - training_data_path: str - Path to the training data. - validation_data_path: str - Path to the validation data. - packing: bool - Whether or not we should apply packing or not. - response_template: Optional[str] - Response template to be used for formatting by TRL. - max_sequence_length: int - Max sequence length to be used for sequence tokenization. - tokenizer: AutoTokenizer - Loaded tokenizer object to be used by the collator. - dataset_text_field: Optional[str] - Dataset text field fto be used for formatting by TRL. 
- - Returns: - Dict[str, Any] - Data related kwargs to be used by the SFT Trainer. - """ - data_collator = get_data_collator_temp( - packing, dataset_text_field, response_template, max_sequence_length, tokenizer - ) - eval_dataset = None - data_kwargs = {} - if isinstance(data_collator, DataCollatorForSeq2Seq): - # HACK: This function is never called, but is needed to sidestep TRL's internal validation. - data_kwargs["formatting_func"] = lambda x: x + else: + # This is for JSON containing input/output fields train_dataset = get_preprocessed_dataset( - training_data_path, + data_args.training_data_path, tokenizer, - max_sequence_length, - input_field_name="input", - output_field_name="output", + max_seq_length, + input_field_name=JSON_INPUT_KEY, + output_field_name=JSON_OUTPUT_KEY, ) - if validation_data_path: + if data_args.validation_data_path: eval_dataset = get_preprocessed_dataset( - validation_data_path, + data_args.validation_data_path, tokenizer, - max_sequence_length, - input_field_name="input", - output_field_name="output", - ) - else: - train_dataset = get_formatted_dataset_with_single_sequence( - training_data_path, dataset_text_field, tokenizer - ) - if validation_data_path: - eval_dataset = get_formatted_dataset_with_single_sequence( - validation_data_path, dataset_text_field, tokenizer + max_seq_length, + input_field_name=JSON_INPUT_KEY, + output_field_name=JSON_OUTPUT_KEY, ) - data_kwargs["data_collator"] = data_collator - data_kwargs["train_dataset"] = train_dataset - data_kwargs["eval_dataset"] = eval_dataset - return data_kwargs + return train_dataset, eval_dataset, dataset_text_field def get_formatted_dataset_with_single_sequence( @@ -396,7 +311,7 @@ def get_jsonl_object(): ### Utils for custom masking / manipulating input / output strs, etc -def combine_sequence(input_element: str, output_element: str): +def combine_sequence(input_element: str, output_element: str, eos_token: str = ""): """Combines / concatenates input & output element. 
Args: @@ -404,6 +319,9 @@ def combine_sequence(input_element: str, output_element: str): Input component of the combined sequence. output_element: str Output component of the combined sequence. + eos_token: str + EOS token associated with the tokenizer. \ + If passed, it will be concatenated at end Returns: str @@ -412,8 +330,8 @@ def combine_sequence(input_element: str, output_element: str): if not input_element.endswith((" ", "\n", "\t")) and not output_element.startswith( (" ", "\n", "\t") ): - return input_element + " " + output_element - return input_element + output_element + return input_element + " " + output_element + eos_token + return input_element + output_element + eos_token def preprocess_and_tokenize( @@ -445,7 +363,7 @@ def preprocess_and_tokenize( Dictionary containing the input IDs/labels/attention mask for this record. """ combined_seq = combine_sequence( - element[input_field_name], element[output_field_name] + element[input_field_name], element[output_field_name], tokenizer.eos_token ) tokenized_comb_seqs = tokenizer( From c7589d17f625565d1175412302e1737aa88876f1 Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Wed, 31 Jul 2024 18:13:20 -0600 Subject: [PATCH 6/6] Revert "limit peft deps until investigate (#274)" (#275) This reverts commit f57ff63650ba139d6e0471d244df4a70e4b13d0b. Signed-off-by: Anh-Uong --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e4bce3e5..3438ecfb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "tokenizers>=0.13.3,<1.0", "tqdm>=4.66.2,<5.0", "trl>=0.9.3,<1.0", -"peft>=0.8.0,<0.12", +"peft>=0.8.0,<0.13", "datasets>=2.15.0,<3.0", "fire>=0.5.0,<1.0", "simpleeval>=0.9.13,<1.0",