diff --git a/src/liger_kernel/transformers/model/qwen2_vl.py b/src/liger_kernel/transformers/model/qwen2_vl.py
index 474c68fc5..a51d59d0b 100644
--- a/src/liger_kernel/transformers/model/qwen2_vl.py
+++ b/src/liger_kernel/transformers/model/qwen2_vl.py
@@ -36,6 +36,7 @@ def lce_forward(
     image_grid_thw: Optional[torch.LongTensor] = None,
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -125,14 +126,30 @@ def lce_forward(
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)

-    if version.parse(transformers_version) > version.parse("4.46.2"):
+    if version.parse(transformers_version) > version.parse("4.46.3"):
         # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
         # https://github.com/huggingface/transformers/issues/33401
         # While correct, this breaks equivalence with past versions of Qwen2-VL from
         # transformers and leads to failed tests or users noticing differences in results.
         # TODO: remove above conditional when liger drops support for transformers<4.47.0
-        if position_ids is None and input_ids is not None:
-            position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
+        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
+        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate RoPE index once per generation in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `deltas` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

     outputs = self.model(
         input_ids=None,
@@ -144,6 +161,7 @@ def lce_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        cache_position=cache_position,
     )

     hidden_states = outputs[0]
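
Note on the second hunk: upstream Qwen2-VL computes the multimodal RoPE indices with get_rope_index only in the pre-fill stage (cache_position[0] == 0), caches the per-sample offset in self.rope_deltas, and on later decode steps rebuilds position ids arithmetically from cache_position plus that cached delta. The standalone sketch below mirrors that decode-step arithmetic, assuming cache_position is provided and the cached rope_deltas has shape (batch_size, 1); the helper name decode_position_ids and the example values are hypothetical, not part of the patch or of the transformers API.

import torch

# Minimal sketch (names hypothetical): the decode-stage position-id arithmetic
# from the hunk above, assuming `rope_deltas` was cached during pre-fill
# with shape (batch_size, 1).
def decode_position_ids(
    cache_position: torch.Tensor,  # e.g. tensor([past_len]) for a 1-token decode step
    rope_deltas: torch.Tensor,     # cached per-sample deltas, shape (batch_size, 1)
    batch_size: int,
    seq_length: int,
) -> torch.Tensor:
    # shift a fresh [0, seq_length) range by the cached per-sample delta
    delta = cache_position[0] + rope_deltas
    position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
    position_ids = position_ids.add(delta)
    # Qwen2-VL rotary embeddings use three position axes (temporal/height/width),
    # hence the leading dimension of 3
    return position_ids.unsqueeze(0).expand(3, -1, -1)

# e.g. one decode step for a batch of 2 after a 7-token pre-fill,
# with cached deltas 3 and 5
pos = decode_position_ids(torch.tensor([7]), torch.tensor([[3], [5]]), batch_size=2, seq_length=1)
print(pos.shape)        # torch.Size([3, 2, 1])
print(pos[0].tolist())  # [[10], [12]]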