flash_attention_2 2.7.2.post1 seems to crash when using torch.compile and DataCollatorWithFlattening #35588

Open
4 tasks
avishaiElmakies opened this issue Jan 9, 2025 · 4 comments
@avishaiElmakies
Contributor

avishaiElmakies commented Jan 9, 2025

System Info

  • transformers version: 4.47.1
  • Platform: Linux-6.6.20-aufs-1-x86_64-with-glibc2.36
  • Python version: 3.11.2
  • Huggingface_hub version: 0.26.2
  • Safetensors version: 0.4.5
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: yes
  • GPU type: NVIDIA RTX A5000

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Update to the latest flash-attention version (2.7.2 at the time of writing). This version should be compatible with torch.compile, as described in https://github.com/Dao-AILab/flash-attention
Load a model with FA2 (tested with OPT and Qwen).
Use Trainer with DataCollatorWithFlattening and train.
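
For readers unfamiliar with the collator, here is a rough, torch-free sketch of the packing idea: examples are concatenated into one row without padding, and position ids restart at 0 for each example. The helper name `flatten_examples` and the token ids are made up for illustration; this is not the library's implementation.

```python
def flatten_examples(examples: list[list[int]]) -> dict[str, list[int]]:
    """Pack variable-length examples into a single row without padding."""
    input_ids: list[int] = []
    position_ids: list[int] = []
    for tokens in examples:
        input_ids.extend(tokens)
        # Positions restart at 0 so each sequence's start can be recovered later.
        position_ids.extend(range(len(tokens)))
    return {"input_ids": input_ids, "position_ids": position_ids}


# Hypothetical token ids, chosen only to make the packing visible.
batch = flatten_examples([[11, 12, 13], [21, 22]])
print(batch["input_ids"])     # [11, 12, 13, 21, 22]
print(batch["position_ids"])  # [0, 1, 2, 0, 1]
```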

Training then crashes with the following stack trace:

Traceback (most recent call last):
  File "/cs/labs/oabend/avishai.elma/slm_eval/slm_eval/train.py", line 89, in main
    trainer.train(resume_from_checkpoint=cfg.cont_training)
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/trainer.py", line 2164, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/trainer.py", line 2524, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/slm_eval/trainer/slam_trainer.py", line 71, in training_step
    return super().training_step(model, inputs, num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/trainer.py", line 3654, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/trainer.py", line 3708, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/accelerate/utils/operations.py", line 823, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/accelerate/utils/operations.py", line 811, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/slm_eval/model/unit_lm.py", line 118, in forward
    def forward(self,
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1109, in forward
    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 895, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 584, in forward
    def forward(
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 364, in forward
    def forward(
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 419, in torch_dynamo_resume_in_forward_at_419
    logger.warning_once(
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/modeling_flash_attention_utils.py", line 231, in _flash_attention_forward
    def _flash_attention_forward(
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/modeling_flash_attention_utils.py", line 329, in torch_dynamo_resume_in__flash_attention_forward_at_329
    max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
           ^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
    self._call(inst)
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
    self.call_function(fn, args, kwargs)
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 1024, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 774, in call_method
    return self.call_apply(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 699, in call_apply
    ).call_function(tx, args, kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 2015, in call_function
    (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph(
                                            ^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 462, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
    self._call(inst)
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
    self.call_function(fn, args, kwargs)
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/variables/torch.py", line 897, in call_function
    tensor_variable = wrap_fx_proxy(
                      ^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2037, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2124, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2082, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2017, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1574, in wrap_fake_exception
    return fn()
           ^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2018, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2150, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2132, in run_node
    return node.target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_function flash_attn._flash_attn_varlen_forward(*(FakeTensor(..., device='cuda:0', size=(s3, s4, s5), dtype=torch.float16,
           grad_fn=<AsStridedBackward0>), FakeTensor(..., device='cuda:0', size=(s6, s7, s8), dtype=torch.float16,
           grad_fn=<Error>), FakeTensor(..., device='cuda:0', size=(s9, s10, s11), dtype=torch.float16,
           grad_fn=<Error>), FakeTensor(..., device='cuda:0', size=(s13,), dtype=torch.int32), FakeTensor(..., device='cuda:0', size=(s13,), dtype=torch.int32), FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64), FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64), 0.0, FloatPow(ToFloat(s5), -0.5)), **{'causal': True, 'window_size_left': -1, 'window_size_right': -1, 'softcap': 0.0, 'alibi_slopes': None, 'return_softmax': False, 'block_table': None}):
flash_attn::_flash_attn_varlen_forward() Expected a value of type 'int' for argument 'max_seqlen_q' but instead found type 'FakeTensor'.
Position: 5
Value: FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64)
Declaration: flash_attn::_flash_attn_varlen_forward(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left=-1, SymInt window_size_right=-1, float softcap=0., Tensor? alibi_slopes=None, bool return_softmax=False, Tensor? block_table=None, Tensor? leftpad_k=None, Tensor? seqused_k=None) -> (Tensor, Tensor, Tensor, Tensor)
Cast error details: Unable to cast Python instance of type <class 'torch._subclasses.fake_tensor.FakeTensor'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)

from user code:
   File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/modeling_flash_attention_utils.py", line 346, in torch_dynamo_resume_in__flash_attention_forward_at_335
    attn_output = flash_attn_varlen_func(
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 1412, in flash_attn_varlen_func
    return FlashAttnVarlenFunc.apply(
  File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 901, in forward
    out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
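
A note on reading the error: "Position: 5" appears to be a zero-based index into the declaration quoted in the message. The stdlib-only sketch below checks which parameter that is. The helper name `param_at` is hypothetical, and the schema string is a shortened copy of the one in the error.

```python
# Schema copied from the error message, shortened to the first nine parameters.
DECLARATION = (
    "flash_attn::_flash_attn_varlen_forward("
    "Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
    "SymInt max_seqlen_q, SymInt max_seqlen_k, float dropout_p, float softmax_scale)"
)


def param_at(declaration: str, position: int) -> tuple[str, str]:
    """Return (type, name) of the zero-based positional parameter."""
    # Keep only the text between the first "(" and the last ")".
    params = declaration[declaration.index("(") + 1 : declaration.rindex(")")]
    type_name, arg_name = params.split(",")[position].split()
    return type_name, arg_name


print(param_at(DECLARATION, 5))  # ('SymInt', 'max_seqlen_q')
```

So the slot declared as SymInt max_seqlen_q received a 0-dim int64 FakeTensor, which is what the cast error reports.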

The code works fine without torch.compile.
With torch.compile but without DataCollatorWithFlattening, the code does not crash.
In that second case (compile, no DataCollatorWithFlattening), I get the following graph break with Qwen2.5:

W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0] Graph break from `Tensor.item()`, consider setting:
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0] or:
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0] to include these operations in the captured graph.
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0] 
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0] Graph break: from user code at:
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]   File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/accelerate/utils/operations.py", line 823, in forward
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]     return model_forward(*args, **kwargs)
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]   File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/accelerate/utils/operations.py", line 811, in __call__
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]     return convert_to_fp32(self.model_forward(*args, **kwargs))
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]   File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]     return func(*args, **kwargs)
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]   File "/cs/labs/oabend/avishai.elma/slm_eval/slm_eval/model/unit_lm.py", line 138, in forward
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]     outputs = self.lm(
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]   File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1165, in forward
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]     outputs = self.model(
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]   File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 864, in forward
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]     causal_mask = self._update_causal_mask(
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]   File "/cs/labs/oabend/avishai.elma/slm_eval/.slm_env2/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 943, in _update_causal_mask
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0]     if attention_mask is not None and 0.0 in attention_mask:
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0] 
W0109 16:15:01.606000 4138800 torch/_dynamo/variables/tensor.py:776] [0/0] 

Expected behavior

Training should not crash.

@Rocketknight1
Member

Maybe @muellerzr @SunMarc? Feel free to ping someone else if you think they're more appropriate

@SunMarc
Member

SunMarc commented Jan 10, 2025

The issue with qwen2 is solved by #35187.
As for DataCollatorWithFlattening, I don't have an answer yet without debugging. It looks like there is an issue with an argument. Can you share a minimal reproducer?

@avishaiElmakies
Contributor Author

avishaiElmakies commented Jan 11, 2025

Hi @SunMarc

The code below seems to reproduce the issue. I don't think it is directly related to DataCollatorWithFlattening; it seems more related to the code path the model takes because of the position ids it receives.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# AutoTokenizer cannot be instantiated directly; use from_pretrained.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model = model.cuda()
model = torch.compile(model)

# Tokenize two texts without padding, then pack them into a single row.
outputs = tokenizer(["Text", "Second_text"])
input_ids = []
position_ids = []
for item in outputs["input_ids"]:
    item = torch.as_tensor(item)
    input_ids.append(item)
    # Positions restart at 0 for each text.
    position_ids.append(torch.arange(item.shape[-1]))
input_ids = torch.cat(input_ids).unsqueeze(0).cuda()
position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
model(input_ids=input_ids, position_ids=position_ids)  # crashes with the same error
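
The position ids built above restart at 0 for each text, which is presumably how a variable-length attention path can tell where each sequence begins. Below is a torch-free sketch of that bookkeeping, under the assumption that the boundaries are recovered from the resets. The helper name `varlen_metadata` and the input values are hypothetical; the real code operates on tensors.

```python
def varlen_metadata(position_ids: list[int]) -> tuple[list[int], int]:
    """Derive cumulative sequence offsets and the longest sequence length.

    Assumes a non-empty list whose first entry is 0.
    """
    # A new sequence starts wherever the position counter resets to 0.
    starts = [i for i, pos in enumerate(position_ids) if pos == 0]
    cu_seqlens = starts + [len(position_ids)]
    # A plain Python int, the kind of value the SymInt max_seqlen_q slot accepts.
    max_seqlen = max(position_ids) + 1
    return cu_seqlens, max_seqlen


cu_seqlens, max_seqlen = varlen_metadata([0, 1, 2, 0, 1])
print(cu_seqlens)  # [0, 3, 5]
print(max_seqlen)  # 3
```

In this sketch max_seqlen is an ordinary int. The error in this issue shows a 0-dim tensor arriving in that slot under torch.compile instead.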

@SunMarc
Member

SunMarc commented Jan 13, 2025

Thanks for the reproducer, I'll check soon! cc @MekkCyber if you have a bit of time to look at this!


3 participants