Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
dchourasia committed Jul 30, 2024
2 parents ed6d7a9 + 537215f commit 3b7ff6b
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 15 deletions.
103 changes: 94 additions & 9 deletions build/accelerate_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@
import tempfile
import shutil
from pathlib import Path
import json

# Third Party
from accelerate.commands.launch import launch_command
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from torch import bfloat16

# Local
from build.utils import (
Expand All @@ -44,10 +48,18 @@
USER_ERROR_EXIT_CODE,
INTERNAL_ERROR_EXIT_CODE,
)
from tuning.data import tokenizer_data_utils

ERROR_LOG = "/dev/termination-log"


def get_base_model_from_adapter_config(adapter_config):
"""Given path to adapter_config.json file, returns the base model name"""
with open(adapter_config, "r", encoding="utf-8") as config_file:
adapter_config = json.load(config_file)
return adapter_config.get("base_model_name_or_path")


def main():
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)
Expand Down Expand Up @@ -118,16 +130,89 @@ def main():
sys.exit(INTERNAL_ERROR_EXIT_CODE)

try:
# copy last checkpoint into mounted output dir
pt_checkpoint_dir = get_highest_checkpoint(tempdir)
logging.info(
"Copying last checkpoint %s into output dir %s",
pt_checkpoint_dir,
original_output_dir,
)
copy_checkpoint(
os.path.join(tempdir, pt_checkpoint_dir), original_output_dir
last_checkpoint_dir = get_highest_checkpoint(tempdir)
last_checkpoint_path = os.path.join(tempdir, last_checkpoint_dir)

use_flash_attn = job_config.get("use_flash_attn", True)
adapter_config_path = os.path.join(
last_checkpoint_path, "adapter_config.json"
)
tokenizer = AutoTokenizer.from_pretrained(last_checkpoint_path)

if os.path.exists(adapter_config_path):
base_model_path = get_base_model_from_adapter_config(
adapter_config_path
)
base_model = AutoModelForCausalLM.from_pretrained(
base_model_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)

# since the peft library (PEFTModelForCausalLM) does not handle cases
# where the model's layers are modified, in our case the embedding layer
# is modified, so we resize the backbone model's embedding layer with our own
# utility before passing it along to load the PEFT model.
tokenizer_data_utils.tokenizer_and_embedding_resize(
{}, tokenizer=tokenizer, model=base_model
)
model = PeftModel.from_pretrained(
base_model,
last_checkpoint_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)
else:
model = AutoModelForCausalLM.from_pretrained(
last_checkpoint_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)

model_arch = model.config.model_type
# check that it is a granite model with llama architecture with tied weights
# ie. lm_head is duplicate of embeddings

# a fine tuned model will have params_dict.get("model.embed_tokens.weight")
# a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight")
# a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight")
copy_checkpoint_bool = True
if model_arch == "llama" and hasattr(model, "lm_head"):
if (
# lora tuned model has an addt model layer
(
hasattr(model.model, "model")
and model.lm_head.weight.untyped_storage().data_ptr()
== model.model.model.embed_tokens.weight.untyped_storage().data_ptr()
)
# prompt tuned model or fine tuned model
or (
hasattr(model.model, "embed_tokens")
and model.lm_head.weight.untyped_storage().data_ptr()
== model.model.embed_tokens.weight.untyped_storage().data_ptr()
)
):

copy_checkpoint_bool = False
logging.info("Removing lm_head from checkpoint")
del model.lm_head.weight

if hasattr(model, "lm_head.weight"):
logging.warning("Failed to delete lm_head.weight from model")

logging.info("Saving checkpoint to %s", original_output_dir)
model.save_pretrained(original_output_dir)
# save tokenizer with model
tokenizer.save_pretrained(original_output_dir)

# copy last checkpoint into mounted output dir
if copy_checkpoint_bool:
logging.info(
"Copying last checkpoint %s into output dir %s",
last_checkpoint_dir,
original_output_dir,
)
copy_checkpoint(last_checkpoint_path, original_output_dir)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log(
Expand Down
11 changes: 5 additions & 6 deletions tuning/trackers/tracker_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,9 @@ def get_tracker(name: str, tracker_configs: TrackerConfigFactory):
C = meta["config"]
T = meta["tracker"]

if tracker_configs is not None:
_conf = _get_tracker_config_by_name(name, tracker_configs)
if _conf is not None:
config = C(**_conf)
else:
config = C()
_conf = _get_tracker_config_by_name(name, tracker_configs)
if _conf is not None:
config = C(**_conf)
else:
config = C()
return T(config)

0 comments on commit 3b7ff6b

Please sign in to comment.