From 05624642802c7f90dcc7aeea0e1c8d447cde006e Mon Sep 17 00:00:00 2001
From: Daniel Han
Date: Mon, 29 Jan 2024 17:49:54 +1100
Subject: [PATCH] Fix inference attention mask (#142)

* faster saving & inference
* Update llama.py
* Update save.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update mistral.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* fast inference
* Update llama.py
* Update save.py
* Update llama.py
* Mistral correct RoPE scaling
* Max sequence lengths
* Apache 2
* fast_linear_forward
* Update utils.py
* Update utils.py
* No print
* Update utils.py
* Update utils.py
* inference
* Update llama.py
* Fast inference RoPE
* Update llama.py
* Update llama.py
* RoPE
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* LoRA
* Fast LoRA saving
* Update llama.py
* hidden_states
* q_len == 1
* q_len issue
* Update mistral.py
* Update mistral.py
* incorrect inference
* Update to transformers 4.37
* Graceful FA2 error + torch 2.1.1
* Update mapper.py
* Update pyproject.toml
* Fix saving and bnb-4bit
* Update fast_lora.py
* Update fast_lora.py
* remove patching
* Update llama.py
* Update llama.py
* Update swiglu.py
* Repatch
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update llama.py
* Update fast_lora.py
* Update llama.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update swiglu.py
* Update fast_lora.py
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update save.py
* Update fast_lora.py
* Update utils.py
* Update llama.py
* Update fast_lora.py
* Update swiglu.py
* Update save.py
* Update save.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Revert "Update llama.py"

  This reverts commit a208ec46e012cf470ecefe6268a66358215df7b6.

* Update llama.py
* Works?
* Update pyproject.toml
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Swiglu
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* attention_mask
* Update llama.py
* Update llama.py
* labels
* Update mistral.py
* Update llama.py
* attention mask
* Update save.py
* Update save.py
* Update mistral.py
* attention mask
* Update llama.py
* Update llama.py
* Update mistral.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update dpo.py
* Patch saving
* Update save.py
* Update save.py
* patch_saving_functions
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* print
* Mistral patch
* Update mistral.py
* Update save.py
* saving
* Update llama.py
* Update llama.py
---
 unsloth/models/llama.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index fcaa2a19..36bb58f3 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -488,7 +488,10 @@ def LlamaModel_fast_forward(
 
     # Fix up attention mask by setting elements to 0
     # Specifically for DPO
-    if self._has_no_labels and attention_mask is not None:
+    if self._has_no_labels and attention_mask is not None and \
+        attention_mask.shape[1] == seq_length:
+        # Careful for inference the attention_mask is size (1, kv_seq_len)
+        # Whilst the input_embeds is size (1, 1, 4096)
         inputs_requires_grad = inputs_embeds.requires_grad
         if inputs_requires_grad: inputs_embeds.requires_grad_(False)
         inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
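
For context, the following is a minimal standalone sketch of why the new attention_mask.shape[1] == seq_length guard matters; it is not unsloth's actual LlamaModel_fast_forward, and the helper name zero_masked_embeddings plus the toy shapes are made up for illustration. During training the mask spans the current sequence and the element-wise zeroing broadcasts cleanly, but during incremental decoding the mask covers the whole KV cache while only one new embedding of shape (1, 1, hidden_size) is passed in, so the multiply must be skipped.

# Hypothetical sketch only -- mirrors the guarded multiply in the patch above.
import torch

hidden_size = 4096  # matches the "(1, 1, 4096)" noted in the patch comment

def zero_masked_embeddings(inputs_embeds, attention_mask):
    # inputs_embeds  : (batch, seq_length, hidden_size)
    # attention_mask : (batch, kv_seq_len), 1 = keep token, 0 = padding
    seq_length = inputs_embeds.shape[1]
    if attention_mask is not None and attention_mask.shape[1] == seq_length:
        # Same reshaping as the patch: (batch, kv) -> (batch, kv, 1), broadcast over hidden dim
        inputs_embeds = inputs_embeds * attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
    return inputs_embeds

# Training-style call: mask length equals seq_length, so padded positions are zeroed.
train_embeds = torch.randn(1, 8, hidden_size)
train_mask   = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0]])
assert torch.all(zero_masked_embeddings(train_embeds, train_mask)[0, 6:] == 0)

# Decoding-style call: one new token, but the mask spans the whole KV cache (length 8).
# Without the shape guard the broadcast would be incorrect; with it, the embeds pass through.
decode_embeds = torch.randn(1, 1, hidden_size)
assert zero_masked_embeddings(decode_embeds, train_mask).shape == (1, 1, hidden_size)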