diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index fcaa2a19..36bb58f3 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -488,7 +488,10 @@ def LlamaModel_fast_forward(
     # Fix up attention mask by setting elements to 0
     # Specifically for DPO
-    if self._has_no_labels and attention_mask is not None:
+    if self._has_no_labels and attention_mask is not None and \
+        attention_mask.shape[1] == seq_length:
+        # Careful for inference the attention_mask is size (1, kv_seq_len)
+        # Whilst the input_embeds is size (1, 1, 4096)
         inputs_requires_grad = inputs_embeds.requires_grad
         if inputs_requires_grad: inputs_embeds.requires_grad_(False)
         inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
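
For context, the new `attention_mask.shape[1] == seq_length` guard matters because during generation the mask spans the whole KV cache while `inputs_embeds` only carries the newly decoded token, so the DPO zero-masking multiply no longer lines up. Below is a minimal sketch, not the library code, illustrating the broadcast the patch performs and why it must be skipped when the shapes diverge; `zero_masked_embeds` and the tensor sizes are hypothetical (the 4096 hidden size just mirrors the comment in the patch).

```python
import torch

HIDDEN_SIZE = 4096  # assumed, per the "(1, 1, 4096)" note in the patch comment

def zero_masked_embeds(inputs_embeds: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Zero out embeddings at padded positions, mimicking the DPO path in the diff.

    inputs_embeds:  (batch, seq_length, hidden_size)
    attention_mask: (batch, mask_length) with 1 = keep, 0 = padding
    """
    seq_length = inputs_embeds.shape[1]
    # At inference time the mask covers the KV cache (mask_length == kv_seq_len)
    # while inputs_embeds holds only the current token (seq_length == 1), so the
    # element-wise multiply would misalign. The patch skips it in that case.
    if attention_mask.shape[1] != seq_length:
        return inputs_embeds
    # (batch, mask_length) -> (batch, seq_length, 1), broadcast over hidden_size,
    # matching attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2) in the diff.
    return inputs_embeds * attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)

# Training-style call: mask length equals sequence length, padded rows become zero.
embeds = torch.randn(2, 8, HIDDEN_SIZE)
mask   = torch.tensor([[1] * 8, [1] * 5 + [0] * 3])
out = zero_masked_embeds(embeds, mask)
assert torch.all(out[1, 5:] == 0)

# Inference-style call: mask spans the cache (8) but only 1 new token arrives,
# so the guard leaves the embeddings untouched.
step_embeds = torch.randn(2, 1, HIDDEN_SIZE)
out = zero_masked_embeds(step_embeds, mask)
assert out.shape == (2, 1, HIDDEN_SIZE)
```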