Skip to content

Commit

Permalink
Let Huggingface Properly Initialize Arguments, and Fix FSDP-LORA Chec…
Browse files Browse the repository at this point in the history
…kpoint-Saves and Resumption (foundation-model-stack#53)

* training args should call post init to intialize all HF flags

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* remove run_distribtued flag and peft_saving callback

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* revert deletion of validation checks on some train args

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* revert the addition of __post_init__ as it is actually not needed

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Co-authored-by: Yu Chin Fabian Lim <[email protected]>
Co-authored-by: Sukriti Sharma <[email protected]>
  • Loading branch information
3 people authored Mar 9, 2024
1 parent 3f83a3d commit 0729820
Showing 1 changed file with 2 additions and 22 deletions.
24 changes: 2 additions & 22 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,6 @@
from tuning.utils.data_type_utils import get_torch_dtype


class PeftSavingCallback(TrainerCallback):
def on_save(self, args, state, control, **kwargs):
checkpoint_path = os.path.join(
args.output_dir, f"checkpoint-{state.global_step}"
)
kwargs["model"].save_pretrained(checkpoint_path)

if "pytorch_model.bin" in os.listdir(checkpoint_path):
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))


class FileLoggingCallback(TrainerCallback):
"""Exports metrics, e.g., training loss to a file in the checkpoint directory."""

Expand Down Expand Up @@ -118,7 +107,6 @@ def train(
None for fine tuning
The peft configuration to pass to trainer
"""
run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1

logger = logging.get_logger("sft_trainer")

Expand All @@ -132,11 +120,6 @@ def train(
):
raise ValueError("gradient_accumulation_steps has to be an integer >= 1")

# make sure to unset FSDP args when running on single gpu
if not run_distributed:
train_args.fsdp = ""
train_args.fsdp_config = {"xla": False}

task_type = "CAUSAL_LM"
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
Expand All @@ -147,8 +130,6 @@ def train(

peft_config = get_hf_peft_config(task_type, peft_config)

model.gradient_checkpointing_enable()

# TODO: Move these to a config as well
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, cache_dir=train_args.cache_dir, use_fast=True
Expand Down Expand Up @@ -239,8 +220,7 @@ def train(

aim_callback = get_aimstack_callback()
file_logger_callback = FileLoggingCallback(logger)
peft_saving_callback = PeftSavingCallback()
callbacks = [aim_callback, peft_saving_callback, file_logger_callback]
callbacks = [aim_callback, file_logger_callback]

if train_args.packing:
logger.info("Packing is set to True")
Expand Down Expand Up @@ -281,7 +261,7 @@ def train(
peft_config=peft_config,
)

if run_distributed and peft_config is not None:
if trainer.is_fsdp_enabled and peft_config is not None:
trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
model
)
Expand Down

0 comments on commit 0729820

Please sign in to comment.