# Handle cache_position for transformers 4.47.0 and later (#528) (#529)
## Summary
Fixes [issue #528](#528) by copying the new way of handling RoPE from transformers 4.48.0.

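For reference, this is the updated RoPE handling from the Qwen2-VL forward in transformers 4.48.0 that this PR ports into liger's `lce_forward`: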
```python
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids, image_grid_thw, video_grid_thw, attention_mask
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
```
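For intuition, below is a minimal standalone sketch (not part of the PR) of the decode-step branch above: given `rope_deltas` cached at pre-fill time, it rebuilds the 3D (temporal/height/width) position ids for a single newly generated token. The batch size, `cache_position`, and delta values are made up for illustration.

```python
import torch

batch_size, seq_length = 2, 1            # one token per decode step
cache_position = torch.tensor([7])       # index of the token being decoded
rope_deltas = torch.tensor([[3], [5]])   # hypothetical per-sample deltas, shape (batch, 1)

# Offset of each sample's position ids, derived from the cached deltas
delta = cache_position[0] + rope_deltas                                      # (2, 1)
position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)  # (2, 1)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)        # still (2, 1)
position_ids = position_ids.add(delta)                                      # (2, 1)
# Expand to the 3 mRoPE axes (temporal, height, width)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
print(position_ids.shape)  # torch.Size([3, 2, 1])
```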
## Testing Done

Tested with the following transformers versions (all of them worked with this PR; 4.48.0 did not work without it):
pip install transformers==4.46.2
pip install transformers==4.46.3
pip install transformers==4.48.0

Before this PR, using Qwen2-VL through `liger-kernel` with `transformers>=4.47.0` failed with the error below, because the patched `lce_forward()` did not accept the `cache_position` keyword argument that these transformers versions pass to the model forward ([issue #528](#528)):
```
Traceback (most recent call last):
  File "/workspaces/test/t.py", line 51, in <module>
    generated_ids = model.generate(**inputs, max_new_tokens=128)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/generation/utils.py", line 2255, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/generation/utils.py", line 3254, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: lce_forward() got an unexpected keyword argument 'cache_position'
```
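
After this PR, a quick sanity check (a hypothetical snippet, not part of the repository's test suite) is to confirm that the patched forward now exposes the `cache_position` parameter:

```python
import inspect

from liger_kernel.transformers.model.qwen2_vl import lce_forward

# The patched Qwen2-VL forward should now accept the `cache_position` kwarg
# that transformers >= 4.47.0 passes during generation.
assert "cache_position" in inspect.signature(lce_forward).parameters
```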

Inference test script:
```python
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
from transformers import BitsAndBytesConfig
from qwen_vl_utils import process_vision_info

apply_liger_kernel_to_qwen2_vl()

model_id = "Qwen/Qwen2-VL-2B-Instruct"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=False, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_id)


messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```

Training test script:
```python
import torch
from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from transformers import BitsAndBytesConfig
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_loraplus_optimizer
import bitsandbytes as bnb
from trl import SFTTrainer, SFTConfig
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
from configs_and_helpers import clear_memory, vl_format_data, generate_text_from_sample

apply_liger_kernel_to_qwen2_vl()

model_id = "Qwen/Qwen2-VL-2B-Instruct"
dataset_id = "HuggingFaceM4/ChartQA"

system_message = """You are a Vision Language Model specialized in interpreting visual data from chart images.
Your task is to analyze the provided chart image and respond to queries with concise answers, usually a single word, number, or short phrase.
The charts include a variety of types (e.g., line charts, bar charts) and contain colors, labels, and text.
Focus on delivering accurate, succinct answers based on the visual information. Avoid additional explanation unless absolutely necessary."""

train_dataset, eval_dataset, test_dataset = load_dataset(dataset_id, split=["train[:20%]", "val[:2%]", "test[:1%]"])

train_dataset = [vl_format_data(sample, system_message) for sample in train_dataset]
eval_dataset = [vl_format_data(sample, system_message) for sample in eval_dataset]
test_dataset = [vl_format_data(sample, system_message) for sample in test_dataset]


model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

processor = Qwen2VLProcessor.from_pretrained(model_id)
print(f"{train_dataset[0]=}")
print(f"{train_dataset[0][1:2]=}")


output = generate_text_from_sample(model, processor, train_dataset[0])
print(f"{output=}")

clear_memory(globals())

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=False, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    use_cache=False
)

processor = Qwen2VLProcessor.from_pretrained(model_id)
processor.padding_side = "right"  # Ensure padding is added to the right side
processor.tokenizer.padding_side = "right"  # Ensure padding is added to the right side

# Configure LoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

# Apply PEFT model adaptation
print(model)
model = get_peft_model(model, peft_config)
print(model)

# Print trainable parameters
model.print_trainable_parameters()

optimizer = create_loraplus_optimizer(
    model=model,
    optimizer_cls=bnb.optim.PagedAdamW8bit,
    # optimizer_cls=torch.optim.AdamW,
    lr=2e-4,
    eps=1e-6,
    # eps=1e-8,
    betas=(0.9, 0.999),
    weight_decay=0.0,
    loraplus_lr_ratio=16,
)
scheduler = None


# Configure training arguments
training_args = SFTConfig(
    output_dir="qwen2-2b-instruct-trl-sft-ChartQA",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    eval_accumulation_steps=4,
    # Logging and evaluation
    logging_steps=1,
    eval_steps=10,
    torch_empty_cache_steps=1,
    eval_strategy="steps",
    save_strategy="epoch",
    bf16=True,
    # Gradient checkpointing settings
    gradient_checkpointing_kwargs={"use_reentrant": False},
    gradient_checkpointing=True,
    # Dataset configuration
    dataset_kwargs={"skip_prepare_dataset": True},
    # max_seq_length=1024  # Maximum sequence length for input
    remove_unused_columns=False,  # Keep unused columns in dataset
)

# Create a data collator to encode text and image pairs
def collator_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]
    image_inputs = [process_vision_info(example)[0] for example in examples]  # Process the images to extract inputs

    # Tokenize the texts and process the images
    batch = processor(
        text=texts, images=image_inputs, return_tensors="pt", padding=True
    )

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    # Mask padding tokens in labels
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # Ignore the image token index in the loss computation (model specific)
    if isinstance(processor, Qwen2VLProcessor):
        image_tokens = [151652, 151653, 151655]  # Specific image token IDs for Qwen2VLProcessor
    else:
        image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]  # Convert image token to ID

    # Mask image token IDs in the labels
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100

    batch["labels"] = labels

    return batch

print(f"Processed data:\n{collator_fn(train_dataset[:2])}")

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collator_fn,
    peft_config=peft_config,
    processing_class=processor.tokenizer,
    optimizers=(optimizer, scheduler)
)
trainer.train()
trainer.save_model(training_args.output_dir)

# model.save_model(output_name)
```

My hardware is fairly weak and runs out of memory (OOM) when running `make test`, so further testing may be needed.
- Hardware Type: Nvidia RTX 4070 Laptop
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Co-authored-by: Shao Tang <[email protected]>
BenasdTW and lancerts authored Jan 21, 2025
1 parent a8fa3bb commit 2ea3cfb
Showing 1 changed file: `src/liger_kernel/transformers/model/qwen2_vl.py` (21 additions, 3 deletions)

```diff
@@ -36,6 +36,7 @@ def lce_forward(
     image_grid_thw: Optional[torch.LongTensor] = None,
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -125,14 +126,30 @@ def lce_forward(
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)
 
-    if version.parse(transformers_version) > version.parse("4.46.2"):
+    if version.parse(transformers_version) > version.parse("4.46.3"):
         # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
         # https://github.com/huggingface/transformers/issues/33401
         # While correct, this breaks equivalence with past versions of Qwen2-VL from
         # transformers and leads to failed tests or users noticing differences in results.
         # TODO: remove above conditional when liger drops support for transformers<4.47.0
-        if position_ids is None and input_ids is not None:
-            position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
+        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
+        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate RoPE index once per generation in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `deltas` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
 
     outputs = self.model(
         input_ids=None,
@@ -144,6 +161,7 @@ def lce_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        cache_position=cache_position,
     )
 
     hidden_states = outputs[0]
```
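Note that the version gate changes from `> 4.46.2` to `> 4.46.3`, so the new `cache_position`-based RoPE path is only taken on transformers 4.47.0 and later, while 4.46.x keeps the previous behavior.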
