From 2ea3cfb9053f6aa4f13d7a64e239e9150865a6d9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=B7=A6=E5=85=B6=E5=8F=B3?= <48852791+BenasdTW@users.noreply.github.com>
Date: Wed, 22 Jan 2025 07:36:52 +0800
Subject: [PATCH] Handle cache_position for transformers 4.47.0 and later (#528) (#529)

## Summary

Fix [issue #528](https://github.com/linkedin/Liger-Kernel/issues/528) by copying the new RoPE handling from transformers 4.48.0:

```python
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
    # calculate RoPE index once per generation in the pre-fill stage only
    if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
        position_ids, rope_deltas = self.get_rope_index(
            input_ids, image_grid_thw, video_grid_thw, attention_mask
        )
        self.rope_deltas = rope_deltas
    # then use the prev pre-calculated rope-deltas to get the correct position ids
    else:
        batch_size, seq_length, _ = inputs_embeds.shape
        delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
        position_ids = torch.arange(seq_length, device=inputs_embeds.device)
        position_ids = position_ids.view(1, -1).expand(batch_size, -1)
        if cache_position is not None:  # otherwise `deltas` is an int `0`
            delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
        position_ids = position_ids.add(delta)
        position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
```

## Testing Done

Tested with the following versions (all of them work with this PR; 4.48.0 does not work without it):

- `pip install transformers==4.46.2`
- `pip install transformers==4.46.3`
- `pip install transformers==4.48.0`

Before applying this PR, running Qwen2-VL through `liger-kernel` with `transformers>=4.47.0` would fail with this error ([issue #528](https://github.com/linkedin/Liger-Kernel/issues/528)):

```
Traceback (most recent call last):
  File "/workspaces/test/t.py", line 51, in <module>
    generated_ids = model.generate(**inputs, max_new_tokens=128)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/generation/utils.py", line 2255, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/generation/utils.py", line 3254, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: lce_forward() got an unexpected keyword argument 'cache_position'
```
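The failure mode is a plain keyword-argument mismatch: during generation, `transformers>=4.47.0` forwards `cache_position` into the model's `forward`, and the patched `lce_forward` did not accept that keyword. A minimal, hypothetical sketch (the function name below is a stand-in, not Liger's actual code) reproduces the same kind of error:

```python
# Hypothetical stand-in for a patched forward written before `cache_position` existed.
def patched_forward(input_ids=None, attention_mask=None, position_ids=None):
    return None

try:
    # transformers>=4.47.0 passes `cache_position` through to the model's forward.
    patched_forward(input_ids=None, cache_position=None)
except TypeError as e:
    print(e)  # patched_forward() got an unexpected keyword argument 'cache_position'
```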
attn_implementation="flash_attention_2", ) processor = AutoProcessor.from_pretrained(model_id) messages = [ { "role": "user", "content": [ { "type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", }, {"type": "text", "text": "Describe this image."}, ], } ] # Preparation for inference text = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) image_inputs, video_inputs = process_vision_info(messages) inputs = processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ) inputs = inputs.to("cuda") # Inference: Generation of the output generated_ids = model.generate(**inputs, max_new_tokens=128) generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] output_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False ) print(output_text) ``` Training test script: ```python import torch from datasets import load_dataset from qwen_vl_utils import process_vision_info from transformers import BitsAndBytesConfig from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl from peft import LoraConfig, get_peft_model from peft.optimizers import create_loraplus_optimizer import bitsandbytes as bnb from trl import SFTTrainer, SFTConfig from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor from configs_and_helpers import clear_memory, vl_format_data, generate_text_from_sample apply_liger_kernel_to_qwen2_vl() model_id = "Qwen/Qwen2-VL-2B-Instruct" dataset_id = "HuggingFaceM4/ChartQA" system_message = """You are a Vision Language Model specialized in interpreting visual data from chart images. Your task is to analyze the provided chart image and respond to queries with concise answers, usually a single word, number, or short phrase. The charts include a variety of types (e.g., line charts, bar charts) and contain colors, labels, and text. Focus on delivering accurate, succinct answers based on the visual information. 
Training test script:

```python
import torch
from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from transformers import BitsAndBytesConfig
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_loraplus_optimizer
import bitsandbytes as bnb
from trl import SFTTrainer, SFTConfig
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
from configs_and_helpers import clear_memory, vl_format_data, generate_text_from_sample

apply_liger_kernel_to_qwen2_vl()

model_id = "Qwen/Qwen2-VL-2B-Instruct"
dataset_id = "HuggingFaceM4/ChartQA"

system_message = """You are a Vision Language Model specialized in interpreting visual data from chart images.
Your task is to analyze the provided chart image and respond to queries with concise answers, usually a single word, number, or short phrase.
The charts include a variety of types (e.g., line charts, bar charts) and contain colors, labels, and text.
Focus on delivering accurate, succinct answers based on the visual information.
Avoid additional explanation unless absolutely necessary."""

train_dataset, eval_dataset, test_dataset = load_dataset(dataset_id, split=["train[:20%]", "val[:2%]", "test[:1%]"])

train_dataset = [vl_format_data(sample, system_message) for sample in train_dataset]
eval_dataset = [vl_format_data(sample, system_message) for sample in eval_dataset]
test_dataset = [vl_format_data(sample, system_message) for sample in test_dataset]

model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
processor = Qwen2VLProcessor.from_pretrained(model_id)

print(f"{train_dataset[0]=}")
print(f"{train_dataset[0][1:2]=}")

output = generate_text_from_sample(model, processor, train_dataset[0])
print(f"{output=}")

clear_memory(globals())

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    use_cache=False
)
processor = Qwen2VLProcessor.from_pretrained(model_id)
processor.padding_side = "right"  # Ensure padding is added to the right side
processor.tokenizer.padding_side = "right"  # Ensure padding is added to the right side

# Configure LoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

# Apply PEFT model adaptation
print(model)
model = get_peft_model(model, peft_config)
print(model)

# Print trainable parameters
model.print_trainable_parameters()

optimizer = create_loraplus_optimizer(
    model=model,
    optimizer_cls=bnb.optim.PagedAdamW8bit,
    # optimizer_cls=torch.optim.AdamW,
    lr=2e-4,
    eps=1e-6,
    # eps=1e-8,
    betas=(0.9, 0.999),
    weight_decay=0.0,
    loraplus_lr_ratio=16,
)
scheduler = None

# Configure training arguments
training_args = SFTConfig(
    output_dir="qwen2-2b-instruct-trl-sft-ChartQA",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    eval_accumulation_steps=4,
    # Logging and evaluation
    logging_steps=1,
    eval_steps=10,
    torch_empty_cache_steps=1,
    eval_strategy="steps",
    save_strategy="epoch",
    bf16=True,
    # Gradient checkpointing settings
    gradient_checkpointing_kwargs={"use_reentrant": False},
    gradient_checkpointing=True,
    # Dataset configuration
    dataset_kwargs={"skip_prepare_dataset": True},
    # max_seq_length=1024  # Maximum sequence length for input
    remove_unused_columns=False  # Keep unused columns in dataset
)


# Create a data collator to encode text and image pairs
def collator_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]
    image_inputs = [process_vision_info(example)[0] for example in examples]  # Process the images to extract inputs

    # Tokenize the texts and process the images
    batch = processor(
        text=texts, images=image_inputs, return_tensors="pt", padding=True
    )

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()  # Mask padding tokens in labels
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # Ignore the image token index in the loss computation (model specific)
    if isinstance(processor, Qwen2VLProcessor):
        image_tokens = [151652, 151653, 151655]  # Specific image token IDs for Qwen2VLProcessor
    else:
        image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]  # Convert image token to ID

    # Mask image token IDs in the labels
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100

    batch["labels"] = labels
    return batch


print(f"Processed data:\n{collator_fn(train_dataset[:2])}")

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collator_fn,
    peft_config=peft_config,
    processing_class=processor.tokenizer,
    optimizers=(optimizer, scheduler)
)

trainer.train()
trainer.save_model(training_args.output_dir)
# model.save_model(output_name)
```

My hardware is fairly weak and goes OOM when running `make test`, so this might need further testing.

- Hardware Type: Nvidia RTX 4070 Laptop
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Co-authored-by: Shao Tang
---
 .../transformers/model/qwen2_vl.py | 24 ++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/src/liger_kernel/transformers/model/qwen2_vl.py b/src/liger_kernel/transformers/model/qwen2_vl.py
index 474c68fc5..a51d59d0b 100644
--- a/src/liger_kernel/transformers/model/qwen2_vl.py
+++ b/src/liger_kernel/transformers/model/qwen2_vl.py
@@ -36,6 +36,7 @@ def lce_forward(
     image_grid_thw: Optional[torch.LongTensor] = None,
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -125,14 +126,30 @@
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)
 
-    if version.parse(transformers_version) > version.parse("4.46.2"):
+    if version.parse(transformers_version) > version.parse("4.46.3"):
         # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
         # https://github.com/huggingface/transformers/issues/33401
         # While correct, this breaks equivalence with past versions of Qwen2-VL from
         # transformers and leads to failed tests or users noticing differences in results.
         # TODO: remove above conditional when liger drops support for transformers<4.47.0
-        if position_ids is None and input_ids is not None:
-            position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
+        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
+        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate RoPE index once per generation in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `deltas` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
 
     outputs = self.model(
         input_ids=None,
@@ -144,6 +161,7 @@
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        cache_position=cache_position,
     )
 
     hidden_states = outputs[0]
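For illustration only (toy values, not part of the patch; the `(batch, 1)` shape of `rope_deltas` is assumed here): after pre-fill caches `rope_deltas`, each decode step derives position ids from `cache_position` plus the cached deltas and expands them to the three position components of Qwen2-VL's multimodal RoPE.

```python
import torch

# Toy decode step: batch of 2 sequences, generating 1 new token each.
batch_size, seq_length = 2, 1
rope_deltas = torch.tensor([[3], [5]])  # cached per-sequence deltas from the pre-fill call (shape assumed (batch, 1))
cache_position = torch.tensor([7])      # index of the new token in the KV cache

delta = cache_position[0] + rope_deltas                                # (2, 1): [[10], [12]]
position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)   # no-op when shapes already match
position_ids = position_ids.add(delta)                                 # (2, 1): [[10], [12]]
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)             # (3, 2, 1): one copy per RoPE section

print(position_ids)
```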