Skip to content

Commit

Permalink
Merge pull request #276 from jbusche/jb-releasev1.1.0
Browse files Browse the repository at this point in the history
Releasev1.1.0
  • Loading branch information
anhuong authored Aug 1, 2024
2 parents 4ed1bc4 + 8b11924 commit ab3b331
Show file tree
Hide file tree
Showing 13 changed files with 433 additions and 110 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ python tuning/sft_trainer.py \
--r 8 \
--lora_dropout 0.05 \
--lora_alpha 16 \
--target_modules ["c_attn", "c_proj"]
--target_modules c_attn c_proj
```

Equally you can pass in a JSON configuration for running tuning. See [build doc](./build/README.md) for more details. The above can also be passed in as JSON:
Expand Down Expand Up @@ -547,7 +547,7 @@ Trainer controller is a framework for controlling the trainer loop using user-de

This framework helps users define rules to capture scenarios like criteria for stopping an ongoing training (E.g validation loss reaching a certain target, validation loss increasing with epoch, training loss values for last 100 steps increasing etc).

For details about how you can use set a custom stopping criteria and perform custom operations, see [examples/trainer_controller/README.md](examples/trainer_controller/README.md)
For details about how you can use set a custom stopping criteria and perform custom operations, see [examples/trainercontroller_configs/Readme.md](examples/trainercontroller_configs/Readme.md)

## More Examples

Expand Down
106 changes: 95 additions & 11 deletions build/accelerate_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,20 @@
import tempfile
import shutil
from pathlib import Path
import json

# Third Party
from accelerate.commands.launch import launch_command
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from torch import bfloat16

# Local
from build.utils import (
process_accelerate_launch_args,
serialize_args,
get_highest_checkpoint,
copy_checkpoint,
)
from tuning.utils.config_utils import get_json_config
from tuning.config.tracker_configs import FileLoggingTrackerConfig
Expand All @@ -43,10 +48,18 @@
USER_ERROR_EXIT_CODE,
INTERNAL_ERROR_EXIT_CODE,
)
from tuning.data import tokenizer_data_utils

ERROR_LOG = "/dev/termination-log"


def get_base_model_from_adapter_config(adapter_config):
"""Given path to adapter_config.json file, returns the base model name"""
with open(adapter_config, "r", encoding="utf-8") as config_file:
adapter_config = json.load(config_file)
return adapter_config.get("base_model_name_or_path")


def main():
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)
Expand Down Expand Up @@ -117,18 +130,89 @@ def main():
sys.exit(INTERNAL_ERROR_EXIT_CODE)

try:
# copy last checkpoint into mounted output dir
pt_checkpoint_dir = get_highest_checkpoint(tempdir)
logging.info(
"Copying last checkpoint %s into output dir %s",
pt_checkpoint_dir,
original_output_dir,
)
shutil.copytree(
os.path.join(tempdir, pt_checkpoint_dir),
original_output_dir,
dirs_exist_ok=True,
last_checkpoint_dir = get_highest_checkpoint(tempdir)
last_checkpoint_path = os.path.join(tempdir, last_checkpoint_dir)

use_flash_attn = job_config.get("use_flash_attn", True)
adapter_config_path = os.path.join(
last_checkpoint_path, "adapter_config.json"
)
tokenizer = AutoTokenizer.from_pretrained(last_checkpoint_path)

if os.path.exists(adapter_config_path):
base_model_path = get_base_model_from_adapter_config(
adapter_config_path
)
base_model = AutoModelForCausalLM.from_pretrained(
base_model_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)

# since the peft library (PEFTModelForCausalLM) does not handle cases
# where the model's layers are modified, in our case the embedding layer
# is modified, so we resize the backbone model's embedding layer with our own
# utility before passing it along to load the PEFT model.
tokenizer_data_utils.tokenizer_and_embedding_resize(
{}, tokenizer=tokenizer, model=base_model
)
model = PeftModel.from_pretrained(
base_model,
last_checkpoint_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)
else:
model = AutoModelForCausalLM.from_pretrained(
last_checkpoint_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)

model_arch = model.config.model_type
# check that it is a granite model with llama architecture with tied weights
# ie. lm_head is duplicate of embeddings

# a fine tuned model will have params_dict.get("model.embed_tokens.weight")
# a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight")
# a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight")
copy_checkpoint_bool = True
if model_arch == "llama" and hasattr(model, "lm_head"):
if (
# lora tuned model has an addt model layer
(
hasattr(model.model, "model")
and model.lm_head.weight.untyped_storage().data_ptr()
== model.model.model.embed_tokens.weight.untyped_storage().data_ptr()
)
# prompt tuned model or fine tuned model
or (
hasattr(model.model, "embed_tokens")
and model.lm_head.weight.untyped_storage().data_ptr()
== model.model.embed_tokens.weight.untyped_storage().data_ptr()
)
):

copy_checkpoint_bool = False
logging.info("Removing lm_head from checkpoint")
del model.lm_head.weight

if hasattr(model, "lm_head.weight"):
logging.warning("Failed to delete lm_head.weight from model")

logging.info("Saving checkpoint to %s", original_output_dir)
model.save_pretrained(original_output_dir)
# save tokenizer with model
tokenizer.save_pretrained(original_output_dir)

# copy last checkpoint into mounted output dir
if copy_checkpoint_bool:
logging.info(
"Copying last checkpoint %s into output dir %s",
last_checkpoint_dir,
original_output_dir,
)
copy_checkpoint(last_checkpoint_path, original_output_dir)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log(
Expand Down
17 changes: 17 additions & 0 deletions build/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@
# Third Party
import torch
from accelerate.commands.launch import launch_command_parser
import shutil


def copy_checkpoint(source, destination):
if not os.path.exists(destination):
os.makedirs(destination)
shutil.copystat(source, destination)
# Have a list of directory objects, now iterate over them.
for item in os.listdir(source):
source_file = os.path.join(source, item)
destination_file = os.path.join(destination, item)
if os.path.isdir(source_file):
# recursive call for subdirectories
copy_checkpoint(source_file, destination_file)
else:
# straight copy.
shutil.copy2(source_file, destination_file)


def get_highest_checkpoint(dir_path):
Expand Down
24 changes: 12 additions & 12 deletions examples/trainercontroller_configs/Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ Trainer controller is a framework for controlling the trainer loop using user-de

### Motivation

This frameworks helps user define rules to capture scenarios like criteria for stopping an ongoing training (E.g validation loss reaching a certain target, validation loss increasing with epoch, training loss values for last 100 steps increasing etc).
This frameworks helps user define rules to capture scenarios like criteria for stopping an ongoing training (e.g., validation loss reaching a certain target, validation loss increasing with epoch, training loss values for last 100 steps increasing, etc).

### Usage
*Note: Evaluation loss and validation loss are the same.*
1. The trainer controller feature can be used and its behavior is controlled by a configuration file (we will illustrate the configuration file below) supplied by the user at the start of the training. Here is a sample of how the user can initiate a trainer controller for a training job, by specifying path to an existing configuration `loss.yaml` in the `./examples/trainercontroller_configs` directory using the flag `--trainer_controller_config_file`:
1. The trainer controller feature can be controlled by a configuration file supplied by the user at the start of the training. Here is a sample of how the user can initiate a trainer controller for a training job, by specifying path to an existing configuration `loss.yaml` in the `./examples/trainercontroller_configs` directory using the flag `--trainer_controller_config_file`:
```shell
python ./tuning/sft_trainer.py \
...
Expand All @@ -32,15 +32,15 @@ This frameworks helps user define rules to capture scenarios like criteria for s
operations:
- hfcontrols.should_training_stop
```
Here is a brief primer on the above configuration. More details could be found [here](./architecture_records/001-trainer-controller-framework.md).
Here is a brief primer on the above configuration. More details could be found [here](../../architecture_records/001-trainer-controller-framework.md). Note that in the following descriptions, we use `metric` and `metric handler` interchangeably to describe a class which exposes numeric information about the training state / relevant computations for use in a `rule` for early termination.
- *Description:* The above configuration stops the training when a **evaluation loss** decreases below 2.25 after two epochs.
- *Metrics:* The configuration uses two metrics listed under `controller-metrics` section. One is named `evalmetric`, which uses an in-built metric class called `EvalMetrics` to expose evaluation loss and the other (`trainer_state`) uses `TrainingState` to expose the current epoch. These are referred to in the `rule` as shown above. There are other metrics also which could be used in place of `evalmetric` and . Here is a list of supported metric classes:
- `Loss`: Exposes the **training loss** after every `on_log` event. See more on trainer events [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerCallback).
- `TrainerState`: This metric exposes the **trainer state** (more on trainer state can be found [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerState)). [Here](tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml) is an example metric which uses both the `TrainerState` and `Loss` metric.
- `EvalMetrics`: This metric exposes all the evaluation metrics used in the training job (E.g evaluation/validation loss). [Here](tests/data/trainercontroller/exposed_metrics.yaml) is an example metric which uses both the `EvalMetrics`.
- *Metrics:* The configuration uses two metrics listed under `controller_metrics` section. One is named `evalmetric`, which uses the built-in metric handler class `EvalMetrics` to expose evaluation loss, and the other, `trainer_state`, uses the built-in metric handler class `TrainingState` to expose the current epoch. These are referred to in the `rule` as shown above. There are other metrics that could also be used in place of `evalmetric` and `trainer_state`. At the time of writing, the supported metric handler classes are as follows:
- `Loss`: This metric handler exposes the **training loss** after every `on_log` event. See more on trainer events [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerCallback).
- `TrainerState`: This metric exposes the **trainer state** (more on trainer state can be found [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerState)). [Here](../../tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml) is an example metric which uses both the `TrainerState` and `Loss` metric.
- `EvalMetrics`: This metric exposes all the evaluation metrics used in the training job (E.g evaluation/validation loss). [Here](../../tests/data/trainercontroller/exposed_metrics.yaml) is an example config which uses the `EvalMetric`'s `eval_loss`.
- `HistoryBasedMetric`: This metric exposes a moving **window** of evaluation metrics and training loss. It is useful to create rules on a history of values (i.e. evaluation metrics and training loss). Following are some examples which illustrate how this metric could be used:
- [epoch-level-eval-loss-patience.yaml](tests/data/trainercontroller/epoch-level-eval-loss-patience.yaml): This configuration performs a threshold test for evaluation loss with a **patience threshold** of 2. I.e suppose the evaluation loss lower threshold is 2, and patience threshold is 3, then the trainer controller will not take an action (E.g. stop training) when the rule becomes true (i.e. evaluation loss is lower than 2) for for three consecutive times.
- [non-decreasing-training-loss.yaml](tests/data/trainercontroller/non-decreasing-training-loss.yaml): This configuration compares the first and last values of a window of training loss samples and determines if the training loss has increased or not. If there is an increase, the training is stopped.
- [epoch-level-eval-loss-patience.yaml](../../tests/data/trainercontroller/epoch-level-eval-loss-patience.yaml): This configuration performs a threshold test for evaluation loss with a **patience threshold** of 2. I.e., suppose the evaluation loss lower threshold is 2, and patience threshold is 3, then the trainer controller will not take an action, e.g., stop the training, when the rule becomes true. i.e., evaluation loss is lower than 2, three consecutive times.
- [non-decreasing-training-loss.yaml](../../tests/data/trainercontroller/non-decreasing-training-loss.yaml): This configuration compares the first and last values of a window of training loss samples and determines if the training loss has increased or not. If there is an increase, the training is stopped.
Let us assume use the below example to understand the usage:
```yaml
Expand Down Expand Up @@ -80,10 +80,10 @@ This frameworks helps user define rules to capture scenarios like criteria for s
```
1. To access the first value in window of evaluation metric `eval_loss`, here is the illustration `history_window["metrics"]["eval_loss"][0]`. In the above YAML, the last element is accessed as follows: `history_window["metrics"]["eval_loss"][-1]`.
1. Similarly, the `history_window["metrics"]["global_step"][0]` is global_step at the time of generation of this evaluation metric and `history_window["metrics"]["epoch"][0]` is the corresponding epoch.
1. Similar approach is followed to access training loss (i.e. `history_window["training_loss"]["loss"][0]` givest the first training loss).
1. A similar approach is followed to access training loss (i.e., `history_window["training_loss"]["loss"][0]` gives the first training loss).
- *Trigger:* There is also a trigger event to decide when the `rule` needs to be evaluated. This event has to be one of the trainer events listed [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerCallback).
- *Rule:* The `rule` is a python statement which could use the metric name (e.g. `loss` in the above case) to define conditions which, when satisfied (it is a boolean condition and should evaluate to True to be satisfied) will trigger the operation(s) listed in `operations`.
- *Trigger:* There is also a trigger event to decide *when* the `rule` needs to be evaluated. This event has to be one of the trainer events listed [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerCallback). The choice of even to trigger on allows for more control, e.g., controlling the times at which we should consider early training termination.
- *Rule:* The `rule` is a python statement which could use the metric name, e.g., `loss` in the above case, to define boolean conditions which, when satisfied, will trigger the operation(s) listed in `operations`.
- *Operation:* The `operations` section lists the operations that could be performed when the `rule` is satisfied (i.e. condition becomes True). Currently, we support only one type of operation class `HFControls` (In this particular example, the class and corresponding operation name `hfcontrols` are not specified explicitly as they are considered default and can be omitted). The `HFControls` class supports all operations listed below. More on these operations can be found [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerControl).
- `hfcontrols.should_training_stop`: Stops the training.
- `hfcontrols.should_epoch_stop`: Interrupts the current epoch.
Expand Down
Loading

0 comments on commit ab3b331

Please sign in to comment.