[peft] If AutoModel is wrapped with PEFT for prompt learning, then extend the attention mask #3000

Open · wants to merge 1 commit into master

Conversation

tomaarsen (Collaborator)

Resolves #2995, resolves huggingface/peft#2154

Hello!

Pull Request overview

  • If AutoModel is wrapped with PEFT for prompt learning, then extend the attention mask

Details

Sentence Transformer models are sometimes trained with their underlying AutoModel wrapped in PEFT, as that can reduce the computational cost of training. In particular, when PEFT is used with prompt learning, virtual tokens (or rather, their input embeddings) are prepended to the input, and the attention_mask is extended accordingly before the base AutoModel is called.

However, the attention mask that is later used in the Pooling module is not updated to match, so it is shorter than the token embeddings being pooled. This PR fixes that.
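For illustration, here is a minimal sketch of the mask extension (the helper name and placement are mine, not the exact PR diff): once the wrapped model has prepended num_virtual_tokens prompt embeddings, the mask handed to Pooling is left-padded with ones to match.

import torch

def extend_attention_mask(attention_mask: torch.Tensor, num_virtual_tokens: int) -> torch.Tensor:
    # Prompt learning prepends `num_virtual_tokens` always-attended positions,
    # so prepend the same number of 1s to the pooling mask.
    prefix = torch.ones(
        attention_mask.shape[0],
        num_virtual_tokens,
        dtype=attention_mask.dtype,
        device=attention_mask.device,
    )
    return torch.cat([prefix, attention_mask], dim=1)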

Concern

My primary concern now is that the model doesn't seem to be able to train well:

import random
import logging
from datasets import load_dataset, Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.models import Pooling, Transformer
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from peft import get_peft_model, PromptTuningConfig, TaskType, PromptTuningInit

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)

# 1. Load a model to finetune with 2. (Optional) model card data
cls_pooling = False
if cls_pooling:
    transformer = Transformer("microsoft/mpnet-base")
    pooling = Pooling(transformer.get_word_embedding_dimension(), "cls")
    model = SentenceTransformer(
        modules=[transformer, pooling],
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="apache-2.0",
            model_name="MPNet base trained on GooAQ triplets",
        ),
    )
else:
    model = SentenceTransformer(
        "microsoft/mpnet-base",
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="apache-2.0",
            model_name="MPNet base trained on GooAQ triplets",
        ),
    )

# Apply PEFT with PromptTuningConfig
peft_config = PromptTuningConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    prompt_tuning_init=PromptTuningInit.RANDOM,
    num_virtual_tokens=1
)
model[0].auto_model = get_peft_model(model[0].auto_model, peft_config)
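# Note: with prompt tuning, PEFT's forward prepends `num_virtual_tokens` learned
# embeddings, so the hidden states returned by the wrapped auto_model are longer
# than the tokenized input; this is why the attention mask extension is needed.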

# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/gooaq", split="train")
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]
eval_dataset: Dataset = dataset_dict["test"]

# 4. Define a loss function
loss = MultipleNegativesRankingLoss(model)

# 5. (Optional) Specify training arguments
run_name = "mpnet-base-gooaq-peft"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=250,
    logging_first_step=True,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
# The corpus: every evaluation answer, plus 20k randomly sampled answers as distractors
# Alternatively, the full corpus: corpus = dict(zip(dataset["id"], dataset["answer"]))
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
corpus = (
    {qid: dataset[qid]["answer"] for qid in queries} |
    {qid: dataset[qid]["answer"] for qid in random.sample(range(len(dataset)), 20_000)}
)
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
dev_evaluator = InformationRetrievalEvaluator(
    corpus=corpus,
    queries=queries,
    relevant_docs=relevant_docs,
    show_progress_bar=True,
    name="gooaq-dev",
)
dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.remove_columns("id"),
    eval_dataset=eval_dataset.remove_columns("id"),
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)

# 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name, private=True)

The model trains poorly regardless of whether I use mean or CLS pooling.

@BenjaminBossan could you 1) verify that the PR diff looks solid at a glance, and 2) let me know whether a model with the following config is expected to train roughly as well as with "full" training?

peft_config = PromptTuningConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    prompt_tuning_init=PromptTuningInit.RANDOM,
    num_virtual_tokens=1
)
- Tom Aarsen

@BenjaminBossan

Thanks a lot for quickly providing this solution to the issue. Regarding the question of whether this should train well: I honestly don't have much experience with the prompt learning methods, so I can't say what I would expect. The fix itself looks correct to me, so I would proceed even if training does not work well at the moment.

Perhaps @VecherVhatuX or @mruniverse8 could give this branch a try and report back their results.
