static cache with mixtral will cause CUDA error: device-side assert triggered #35626

Open

zyxiyy opened this issue Jan 11, 2025 · 1 comment

zyxiyy commented Jan 11, 2025

System Info

None

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Code:

with torch.no_grad():
    past_key_values = StaticCache(
        config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype, layer_device_map=layer_device_map
    )
    cache_position = torch.arange(seq_length, device=torch_device)
    generated_ids = torch.zeros(
        batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
    )
    generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)

    logits = model(
        **inputs, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
    )[0]
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    generated_ids[:, seq_length] = next_token[:, 0]

    # decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
    cache_position = torch.tensor([seq_length + 1], device=torch_device)
    for _ in range(1, NUM_TOKENS_TO_GENERATE):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
            generated_ids[:, cache_position] = next_token.int()
        cache_position += 1

text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(text)
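
For context, the snippet above depends on several names it does not define (model, tokenizer, inputs, seq_length, batch_size, NUM_TOKENS_TO_GENERATE, torch_device, layer_device_map, decode_one_tokens). Below is a minimal sketch of the assumed setup: decode_one_tokens is reconstructed from the traceback further down, while the model ID, prompts, dtype, and device placement are placeholders chosen for illustration, not taken from the original report.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

# Placeholders (assumptions, not from the original report).
torch_device = "cuda"
NUM_TOKENS_TO_GENERATE = 40

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # Mixtral's tokenizer has no pad token by default

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.bfloat16, device_map="auto"
)

prompts = ["Simply put, the theory of relativity states that", "My favourite condiment is"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(torch_device)
batch_size, seq_length = inputs["input_ids"].shape

# `layer_device_map` (passed to StaticCache above) is assumed to map each decoder
# layer index to the device that layer was dispatched to by device_map="auto".

# Reconstructed from the traceback below: greedily decode one token while writing
# into the pre-allocated static cache at `cache_position`.
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True,
    )[0]
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token
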
Error:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/home/bcds/venv/dilab/floe/static_cache_test.ipynb Cell 3 line 22
     20 for _ in range(1, NUM_TOKENS_TO_GENERATE):
     21     with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
---> 22         next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
     23         generated_ids[:, cache_position] = next_token.int()
     24     cache_position += 1

/home/bcds/venv/dilab/floe/static_cache_test.ipynb Cell 3 line 19
     18 def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
---> 19     logits = model(
     20         cur_token,
     21         position_ids=input_pos,
     22         cache_position=cache_position,
     23         past_key_values=past_key_values,
     24         return_dict=False,
     25         use_cache=True
     26     )[0]
     27     new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
     28     return new_token

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/llm/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.conda/envs/llm/lib/python3.9/site-packages/transformers/models/mixtral/modeling_mixtral.py:1283, in MixtralForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
   1280 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1282 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1283 outputs = self.model(
   1284     input_ids=input_ids,
   1285     attention_mask=attention_mask,
   1286     position_ids=position_ids,
   1287     past_key_values=past_key_values,
   1288     inputs_embeds=inputs_embeds,
   1289     use_cache=use_cache,
   1290     output_attentions=output_attentions,
   1291     output_hidden_states=output_hidden_states,
   1292     output_router_logits=output_router_logits,
   1293     return_dict=return_dict,
   1294     cache_position=cache_position,
   1295 )
   1297 hidden_states = outputs[0]
   1298 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/llm/lib/python3.9/site-packages/transformers/models/mixtral/modeling_mixtral.py:998, in MixtralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict, cache_position)
    986     layer_outputs = self._gradient_checkpointing_func(
    987         decoder_layer.__call__,
    988         hidden_states,
   (...)
    995         cache_position,
    996     )
    997 else:
--> 998     layer_outputs = decoder_layer(
    999         hidden_states,
   1000         attention_mask=causal_mask,
   1001         position_ids=position_ids,
   1002         past_key_value=past_key_values,
   1003         output_attentions=output_attentions,
   1004         output_router_logits=output_router_logits,
   1005         use_cache=use_cache,
   1006         cache_position=cache_position,
   1007     )
   1009 hidden_states = layer_outputs[0]
   1011 if use_cache:

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/llm/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.conda/envs/llm/lib/python3.9/site-packages/transformers/models/mixtral/modeling_mixtral.py:724, in MixtralDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, output_router_logits, use_cache, cache_position, **kwargs)
    721 hidden_states = self.input_layernorm(hidden_states)
    723 # Self Attention
--> 724 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    725     hidden_states=hidden_states,
    726     attention_mask=attention_mask,
    727     position_ids=position_ids,
    728     past_key_value=past_key_value,
    729     output_attentions=output_attentions,
    730     use_cache=use_cache,
    731     cache_position=cache_position,
    732 )
    733 hidden_states = residual + hidden_states
    735 # Fully Connected

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.conda/envs/llm/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.conda/envs/llm/lib/python3.9/site-packages/transformers/models/mixtral/modeling_mixtral.py:544, in MixtralSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    541     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
    542     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
--> 544 key_states = repeat_kv(key_states, self.num_key_value_groups)
    545 value_states = repeat_kv(value_states, self.num_key_value_groups)
    547 causal_mask = attention_mask

File ~/.conda/envs/llm/lib/python3.9/site-packages/transformers/models/mixtral/modeling_mixtral.py:262, in repeat_kv(hidden_states, n_rep)
    260     return hidden_states
    261 hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
--> 262 return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
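
As the message notes, the assert is reported asynchronously, so the Python stack trace above may not point at the kernel that actually failed. A minimal sketch of one way to force synchronous launches, assuming the flag is set before any CUDA work happens in the process:

import os

# CUDA_LAUNCH_BLOCKING must be set before the first CUDA kernel launch so that
# launches run synchronously and the failing kernel is attributed to the
# correct Python line.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # safe: the flag is in the environment before any CUDA use
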


Expected behavior

The model should generate correctly.
zyxiyy added the bug label Jan 11, 2025
Rocketknight1 (Member) commented:
cc @gante I think this is a "compiling the entire generation loop" issue again
