### System Info

None

### Who can help?

No response

### Reproduction
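The report's snippet is not self-contained, so here is a minimal setup sketch to make it runnable. The checkpoint name, prompts, and `layer_device_map` value are my assumptions (placeholders, not from the original report); `decode_one_tokens` is reconstructed verbatim from the traceback further down.

```python
# Minimal setup sketch (assumed, not from the original report): the checkpoint,
# prompts, and layer_device_map are placeholders. decode_one_tokens matches the
# definition visible in the traceback below (cell 3, lines 18-28).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

torch_device = "cuda"
NUM_TOKENS_TO_GENERATE = 64
batch_size = 2

model_id = "mistralai/Mixtral-8x7B-v0.1"  # hypothetical checkpoint choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

prompts = ["Hello, my name is", "The capital of France is"]  # placeholder prompts
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(torch_device)
seq_length = inputs["input_ids"].shape[1]
layer_device_map = None  # the report builds one; the exact mapping is not shown

def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    # Single-token forward step, as shown in the traceback below.
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True,
    )[0]
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token
```

The original reproduction then runs as follows.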
```python
with torch.no_grad():
    past_key_values = StaticCache(
        config=model.config, batch_size=2, max_cache_len=4096,
        device=torch_device, dtype=model.dtype, layer_device_map=layer_device_map
    )
    cache_position = torch.arange(seq_length, device=torch_device)
    generated_ids = torch.zeros(
        batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
    )
    generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)

    logits = model(
        **inputs, cache_position=cache_position, past_key_values=past_key_values,
        return_dict=False, use_cache=True
    )[0]
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    generated_ids[:, seq_length] = next_token[:, 0]

    # decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
    cache_position = torch.tensor([seq_length + 1], device=torch_device)
    for _ in range(1, NUM_TOKENS_TO_GENERATE):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
            generated_ids[:, cache_position] = next_token.int()
            cache_position += 1

text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(text)
```

## error:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/home/bcds/venv/dilab/floe/static_cache_test.ipynb Cell 3 line 22
     20 for _ in range(1, NUM_TOKENS_TO_GENERATE):
     21     with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
---> 22         next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
     23         generated_ids[:, cache_position] = next_token.int()
     24         cache_position += 1

/home/bcds/venv/dilab/floe/static_cache_test.ipynb Cell 3 line 19
     18 def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
---> 19     logits = model(
     20         cur_token,
     21         position_ids=input_pos,
     22         cache_position=cache_position,
     23         past_key_values=past_key_values,
     24         return_dict=False,
     25         use_cache=True
     26     )[0]
     27     new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
     28     return new_token

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/llm/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168     output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.conda/envs/llm/lib/python3.9/site-packages/transformers/models/mixtral/modeling_mixtral.py:1283, in MixtralForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
   1280 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1282 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1283 outputs = self.model(
   1284     input_ids=input_ids,
   1285     attention_mask=attention_mask,
   1286     position_ids=position_ids,
   1287     past_key_values=past_key_values,
   1288     inputs_embeds=inputs_embeds,
   1289     use_cache=use_cache,
   1290     output_attentions=output_attentions,
   1291     output_hidden_states=output_hidden_states,
   1292     output_router_logits=output_router_logits,
   1293     return_dict=return_dict,
   1294     cache_position=cache_position,
   1295 )
   1297 hidden_states = outputs[0]
   1298 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/llm/lib/python3.9/site-packages/transformers/models/mixtral/modeling_mixtral.py:998, in MixtralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict, cache_position)
    986     layer_outputs = self._gradient_checkpointing_func(
    987         decoder_layer.__call__,
    988         hidden_states,
   (...)
    995         cache_position,
    996     )
    997 else:
--> 998     layer_outputs = decoder_layer(
    999         hidden_states,
   1000         attention_mask=causal_mask,
   1001         position_ids=position_ids,
   1002         past_key_value=past_key_values,
   1003         output_attentions=output_attentions,
   1004         output_router_logits=output_router_logits,
   1005         use_cache=use_cache,
   1006         cache_position=cache_position,
   1007     )
   1009 hidden_states = layer_outputs[0]
   1011 if use_cache:

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/llm/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168     output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.conda/envs/llm/lib/python3.9/site-packages/transformers/models/mixtral/modeling_mixtral.py:724, in MixtralDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, output_router_logits, use_cache, cache_position, **kwargs)
    721 hidden_states = self.input_layernorm(hidden_states)
    723 # Self Attention
--> 724 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    725     hidden_states=hidden_states,
    726     attention_mask=attention_mask,
    727     position_ids=position_ids,
    728     past_key_value=past_key_value,
    729     output_attentions=output_attentions,
    730     use_cache=use_cache,
    731     cache_position=cache_position,
    732 )
    733 hidden_states = residual + hidden_states
    735 # Fully Connected

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/llm/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168     output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.conda/envs/llm/lib/python3.9/site-packages/transformers/models/mixtral/modeling_mixtral.py:544, in MixtralSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    541     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
    542     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
--> 544 key_states = repeat_kv(key_states, self.num_key_value_groups)
    545 value_states = repeat_kv(value_states, self.num_key_value_groups)
    547 causal_mask = attention_mask

File ~/.conda/envs/llm/lib/python3.9/site-packages/transformers/models/mixtral/modeling_mixtral.py:262, in repeat_kv(hidden_states, n_rep)
    260     return hidden_states
    261 hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
--> 262 return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```

### Expected behavior

Generate correctly.
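As the error message notes, device-side asserts are reported asynchronously, so the `repeat_kv` frame above may not be where the assert actually fired. A minimal sketch of how to rerun synchronously, using only standard PyTorch environment variables:

```python
# Set before torch/CUDA is initialized so kernel launches run synchronously and
# the reported Python frame is the one that actually tripped the assert.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Only takes effect if this PyTorch build was compiled with device-side
# assertion support, per the TORCH_USE_CUDA_DSA note in the error message.
os.environ["TORCH_USE_CUDA_DSA"] = "1"

import torch  # import only after setting the variables
```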
cc @gante I think this is a "compiling the entire generation loop" issue again
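For reference, a sketch of the pattern the static-cache examples use: compile only the one-token decode step and keep the Python generation loop eager. This mirrors the `torch.compile` line that is commented out in the reproduction above; whether it sidesteps the Mixtral device-side assert here is untested.

```python
# Sketch: compile just the per-token forward; the surrounding loop stays eager.
compiled_decode = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)

with torch.no_grad():
    for _ in range(1, NUM_TOKENS_TO_GENERATE):
        next_token = compiled_decode(model, next_token.clone(), None, cache_position, past_key_values)
        generated_ids[:, cache_position] = next_token.int()
        cache_position += 1
```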