
Commit

Merge remote-tracking branch 'upstream/main'
dchourasia committed Aug 1, 2024
2 parents 9e41847 + c7589d1 commit a566688
Showing 27 changed files with 370 additions and 274 deletions.
22 changes: 18 additions & 4 deletions README.md
@@ -50,9 +50,11 @@ pip install fms-hf-tuning[fms-accel]
`fms-acceleration` is a collection of plugins that accelerate fine-tuning / training of large models, as part of the `fms-hf-tuning` suite. For more details, see [this section below](#fms-acceleration).

## Data format
We support two data formats:
We support the following data formats:

1. #### Pre-process the JSON/JSONL dataset
### 1. JSON formats with a single sequence and a specified `response_template` to use for masking on completion

#### 1.1 Pre-process the JSON/JSONL dataset
Pre-process the JSON/JSONL dataset so that each data instance contains a single sequence combining the input and the response. The trainer is configured to expect a response template as a string. For example, if one wants to prepare the `alpaca` format data to feed into this trainer, it can be done with the following code.

```python
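# NOTE: the original snippet is collapsed in this diff view; this is an
# illustrative sketch (assuming the standard alpaca fields "instruction",
# "input", and "output"), not the committed code.
import datasets

PROMPT_INPUT = (
    "Below is an instruction that describes a task, paired with an input "
    "that provides further context. Write a response that appropriately "
    "completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
)

def format_alpaca_fn(example):
    # Render the prompt, then append the response so that each data
    # instance becomes one single sequence under the "output" field.
    prompt = PROMPT_INPUT.format_map(example)
    return {"output": f"{prompt} {example['output']}"}

ds = datasets.load_dataset("json", data_files="./alpaca_data.json")
alpaca_ds = ds["train"].map(format_alpaca_fn, remove_columns=["instruction", "input"])
alpaca_ds.to_json("sft_alpaca_data.json")
```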
@@ -87,7 +89,7 @@ The same approach can be applied to any dataset; more info can be found [here](h

Once the JSON is converted using the formatting function, pass the `dataset_text_field` containing the single sequence to the trainer.

2. #### Format JSON/JSONL on the fly
#### 1.2 Format JSON/JSONL on the fly
Pass a JSON/JSONL and a `data_formatter_template` to apply the formatting function on the fly while tuning. The template should reference JSON fields with `{{field}}`. While tuning, the data will be converted to a single sequence using the template.
JSON field names can contain alphanumeric characters, spaces, and the following special symbols: ".", "_", "-".

@@ -101,8 +103,20 @@ data_formatter_template: `### Input: {{input}} \n\n##Label: {{output}}`

Formatting will happen on the fly while tuning. The keys in the template should match fields in the JSON file. The `response_template` corresponding to the above template will need to be supplied; in this case, `response_template` = `\n\n##Label:`.
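
For illustration, a hypothetical record (not from the repo) such as

```
{"input": "the weather is cold", "output": "complaint"}
```

would be rendered by the template above into the single sequence `### Input: the weather is cold \n\n##Label: complaint`.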

##### In conclusion, if using the `response_template` and a single sequence, either the `data_formatter_template` argument or the `dataset_text_field` needs to be supplied to the trainer.

### 2. JSONL with input and output fields (no response template)

Pass a JSONL containing fields "input" with the source text and "output" with the class labels. Pre-format the input as you see fit. The output field will simply be concatenated to the end of the input to create a single sequence, and the input will be masked.

The "input" and "output" field names are mandatory and cannot be changed.

##### In conclusion, either the `data_formatter_template` argument or `dataset_text_field` needs to be supplied to the trainer.
Example: Train.jsonl

```
{"input": "### Input: Colorado is a state in USA ### Output:", "output": "USA : Location"}
{"input": "### Input: Arizona is also a state in USA ### Output:", "output": "USA : Location"}
```
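
With this format, the first record above is trained on as the single sequence `### Input: Colorado is a state in USA ### Output: USA : Location`, and because the input is masked, the loss is computed only on the `USA : Location` portion.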

## Supported Models

Expand Down
@@ -0,0 +1,14 @@
controller_metrics:
- name: training_loss_window
class: HistoryBasedMetric
arguments:
window_size: 1
controllers:
- name: epoch_level_stop_on_training_loss_below_threshold
triggers:
- on_step_end
rule: len(training_loss_window["training_loss"]["loss"]) == training_loss_window["window_size"] and training_loss_window["training_loss"]["loss"][0] < 2.2 and training_loss_window["training_loss"]["epoch"][0] > 2
config:
trigger_log_level: warning
operations:
- hfcontrols.should_training_stop
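
As the rule reads, this controller stops training at a step boundary once the loss window (size 1) is filled with a loss value below 2.2 that was recorded after epoch 2.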
4 changes: 2 additions & 2 deletions examples/trainercontroller_configs/loss.yaml
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller
triggers:
- on_log
rule: loss < 1.0
rule: training_loss["loss"] < 1.0
operations:
- hfcontrols.should_training_stop
1 change: 1 addition & 0 deletions tests/data/trainercontroller/__init__.py
@@ -77,3 +77,4 @@
TRAINER_CONFIG_TEST_THRESHOLDED_TRAINING_LOSS_YAML = os.path.join(
_DATA_DIR, "thresholded-training-loss.yaml"
)
TRAINER_CONFIG_TEST_ON_SAVE_YAML = os.path.join(_DATA_DIR, "on-save.yaml")
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_custom_operation.yaml
@@ -1,5 +1,5 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
operations:
- name: custom_operation
@@ -8,6 +8,6 @@ controllers:
- name: loss_controller_custom_operation
triggers:
- on_log
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- custom_operation.should_perform_action_xyz
@@ -1,5 +1,5 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
operations:
- name: custom_operation
@@ -8,6 +8,6 @@ controllers:
- name: loss_controller_custom_operation_invalid_action
triggers:
- on_log
rule: loss < 1.0
rule: training_loss["loss"] < 1.0
operations:
- custom_operation.should_
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_invalid_metric.yaml
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: MissingMetricClass
controllers:
- name: loss_controller_invalid_metric
triggers:
- on_log
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- hfcontrols.should_training_stop
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_invalid_operation.yaml
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller_invalid_operation
triggers:
- on_log
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- missingop.should_training_stop
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller_invalid_operation_action
triggers:
- on_log
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- hfcontrols.missingaction
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_invalid_trigger.yaml
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller_invalid_trigger
triggers:
- log_it_all_incorrect_trigger_name
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- hfcontrols.should_training_stop
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_on_threshold.yaml
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller
triggers:
- on_log
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- hfcontrols.should_training_stop
@@ -1,12 +1,12 @@
controller_metrics:
- name: state
class: TrainingState
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller
triggers:
- on_log
rule: loss < 2 and state["epoch"] >= 0.5
rule: training_loss['loss'] < 2 and state["epoch"] >= 0.5
operations:
- hfcontrols.should_training_stop
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_unavailable_metric.yaml
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller_unavailable_metric
triggers:
- on_step_end
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- hfcontrols.should_training_stop
@@ -1,5 +1,5 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller_wrong_input_rule
@@ -1,5 +1,5 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller_wrong_os_rule
10 changes: 10 additions & 0 deletions tests/data/trainercontroller/on-save.yaml
@@ -0,0 +1,10 @@
controller_metrics:
- name: state
class: TrainingState
controllers:
- name: stop_on_training_loss_on_save
triggers:
- on_save
rule: state["epoch"] >= 0.5
operations:
- hfcontrols.should_training_stop
56 changes: 56 additions & 0 deletions tests/test_sft_trainer.py
@@ -35,6 +35,7 @@
MALFORMATTED_DATA,
MODEL_NAME,
TWITTER_COMPLAINTS_DATA,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT,
TWITTER_COMPLAINTS_JSON_FORMAT,
)

@@ -724,3 +725,58 @@ def test_run_with_good_experimental_metadata():
additional_callbacks=[TrainerCallback()],
exp_metadata=metadata,
)


### Tests for pretokenized data
def test_pretokenized_dataset():
"""Ensure that we can provide a pretokenized dataset with input/output format."""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
data_args = copy.deepcopy(DATA_ARGS)
data_args.dataset_text_field = None
data_args.response_template = None
data_args.training_data_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT
sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)
_validate_training(tempdir)


@pytest.mark.parametrize(
"dataset_text_field,response_template",
[
("foo", None),
(None, "bar"),
],
)
def test_pretokenized_dataset_bad_args(dataset_text_field, response_template):
"""Ensure that we can't provide only dataset text field / response template for pretok data."""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

data_args = copy.deepcopy(DATA_ARGS)
data_args.dataset_text_field = dataset_text_field
data_args.response_template = response_template
data_args.training_data_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT
# We should raise an error since we should not have a dataset text
# field or a response template if we have pretokenized data
with pytest.raises(ValueError):
sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)


def test_pretokenized_dataset_wrong_format():
"""Ensure that we fail to generate data if the data is in the wrong format."""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

data_args = copy.deepcopy(DATA_ARGS)
data_args.dataset_text_field = None
data_args.response_template = None
data_args.training_data_path = TWITTER_COMPLAINTS_DATA

# It would be best to handle this in a way that is more understandable; we might
# need to directly add validation prior to the dataset generation since datasets
# is essentially swallowing a KeyError here.
with pytest.raises(ValueError):
sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)
7 changes: 0 additions & 7 deletions tests/trainercontroller/custom_operation.py
@@ -26,13 +26,6 @@
class CustomOperation(Operation):
"""Implements a custom operation for testing"""

def __init__(self, **_):
"""Initializes the custom operation class.
Args:
kwargs: List of arguments (key, value)-pairs
"""
super().__init__()

def should_perform_action_xyz(self, control: TrainerControl, **_):
"""This method performs a set training stop flag action.
7 changes: 0 additions & 7 deletions tests/trainercontroller/custom_operation_invalid_action.py
@@ -26,13 +26,6 @@
class CustomOperationInvalidAction(Operation):
"""Implements a custom operation for testing"""

def __init__(self, **_):
"""Initializes the custom operation class.
Args:
kwargs: List of arguments (key, value)-pairs
"""
super().__init__()

def should_(self, control: TrainerControl, **_):
"""This method defines an action within an invalid name.
16 changes: 16 additions & 0 deletions tests/trainercontroller/test_tuning_trainercontroller.py
@@ -138,6 +138,22 @@ def test_thresholded_training_loss():
assert control.should_training_stop is True


def test_thresholded_training_loss_on_save():
"""Tests the thresholded training loss example in
`examples/trainer-controller-configs/on-save.yaml`
"""
test_data = _setup_data()
tc_callback = tc.TrainerControllerCallback(td.TRAINER_CONFIG_TEST_ON_SAVE_YAML)
control = TrainerControl(should_training_stop=False)
# Trigger on_init_end to perform registration of handlers to events
tc_callback.on_init_end(
args=test_data.args, state=test_data.states[2], control=control
)
# Trigger rule and test the condition
tc_callback.on_save(args=test_data.args, state=test_data.states[2], control=control)
assert control.should_training_stop is True


def test_non_decreasing_training_loss():
"""Tests the non-decreasing training loss example in
`examples/trainer-controller-configs/non-decreasing-training-loss.yaml`
