
Fix batch > 1 in HunyuanVideo #10548

Merged
merged 1 commit into from
Jan 14, 2025

Conversation

@hlky (Collaborator) commented Jan 13, 2025

What does this PR do?

Fixes #10542

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@betterftr

I still get an error with this, here:

epoch:   0%|                                                                                   | 0/200 [00:11<?, ?it/s]
Traceback (most recent call last):
  File "C:\OneTrainer\modules\ui\TrainUI.py", line 561, in __training_thread_function
    trainer.train()
  File "C:\OneTrainer\modules\trainer\GenericTrainer.py", line 682, in train
    model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
  File "C:\OneTrainer\modules\modelSetup\BaseHunyuanVideoSetup.py", line 317, in predict
    predicted_flow = model.transformer(
  File "C:\OneTrainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\OneTrainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\OneTrainer\venv\lib\site-packages\diffusers\models\transformers\transformer_hunyuan_video.py", line 770, in forward
    hidden_states, encoder_hidden_states = block(
  File "C:\OneTrainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\OneTrainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\OneTrainer\modules\util\checkpointing_util.py", line 119, in forward
    return checkpoint(
  File "C:\OneTrainer\venv\lib\site-packages\torch\_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "C:\OneTrainer\venv\lib\site-packages\torch\_dynamo\eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
  File "C:\OneTrainer\venv\lib\site-packages\torch\utils\checkpoint.py", line 489, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "C:\OneTrainer\venv\lib\site-packages\torch\autograd\function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "C:\OneTrainer\venv\lib\site-packages\torch\utils\checkpoint.py", line 264, in forward
    outputs = run_function(*args)
  File "C:\OneTrainer\modules\util\checkpointing_util.py", line 89, in offloaded_custom_forward
    output = orig_forward(*args)
  File "C:\OneTrainer\venv\lib\site-packages\diffusers\models\transformers\transformer_hunyuan_video.py", line 478, in forward
    attn_output, context_attn_output = self.attn(
  File "C:\OneTrainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\OneTrainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\OneTrainer\venv\lib\site-packages\diffusers\models\attention_processor.py", line 588, in forward
    return self.processor(
  File "C:\OneTrainer\venv\lib\site-packages\diffusers\models\transformers\transformer_hunyuan_video.py", line 117, in __call__
    hidden_states = F.scaled_dot_product_attention(
RuntimeError: The expanded size of the tensor (24) must match the existing size (16) at non-singleton dimension 1.  Target sizes: [16, 24, 1085, 1085].  Tensor sizes: [16, 1, 1085]
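The RuntimeError above is a broadcasting failure: F.scaled_dot_product_attention expects attn_mask to be broadcastable to the 4-D attention-score shape [batch, heads, q_len, kv_len], but the 3-D mask [16, 1, 1085] ends up aligning its batch dimension (16) against the head dimension (24). A minimal sketch of the shape arithmetic, using NumPy broadcasting (which follows the same alignment rules as PyTorch); the shapes come from the traceback, while the singleton-head-dimension fix is illustrative and not necessarily the exact change made in this PR:

```python
import numpy as np

# SDPA requires attn_mask to broadcast against the attention scores,
# shaped [batch, heads, q_len, kv_len]. Shapes mirror the traceback:
target = (16, 24, 1085, 1085)   # attention scores
bad_mask = (16, 1, 1085)        # 3-D mask as passed in

# Broadcasting aligns trailing dimensions, so the mask's batch dim (16)
# lines up with the head dim (24) and the expansion fails:
try:
    np.broadcast_shapes(bad_mask, target)
except ValueError as e:
    print("broadcast fails:", e)

# Inserting a singleton head dimension makes the mask broadcastable:
good_mask = (16, 1, 1, 1085)    # [batch, 1 head, 1 query, kv_len]
print(np.broadcast_shapes(good_mask, target))  # (16, 24, 1085, 1085)
```

The same mismatch only surfaces with batch > 1 because at batch size 1 the mask's leading dimension is itself a singleton and broadcasts cleanly against the head dimension.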

@hlky (Collaborator, Author) commented Jan 13, 2025

@betterftr

  File "C:\OneTrainer\venv\lib\site-packages\diffusers\models\transformers\transformer_hunyuan_video.py", line 770, in forward
    hidden_states, encoder_hidden_states = block(

Your traceback points at line 770, but that call is on line 771 on this branch; can you confirm you are testing with this PR?

@betterftr

betterftr commented Jan 13, 2025

Yeah, I copied the entire file from: https://github.com/huggingface/diffusers/blob/451ca0af739f16fa93aa5028d8fa24a08ae85cdc/src/diffusers/models/transformers/transformer_hunyuan_video.py


Actually, wait, you might be right: my trainer does not use this copy but the one at venv\lib\site-packages\diffusers\models\transformers\transformer_hunyuan_video.py, and I had copied the file to /src/diffusers.

Re-running the test.

Edit 3: Okay, you were right, my bad; the issue is gone, sorry :)

@DN6 DN6 requested a review from a-r-r-o-w January 13, 2025 16:06
@a-r-r-o-w (Member) left a comment
Thanks! It's working as expected. I wonder why our batch inference tests didn't catch this in both occurrences.

@DN6 DN6 merged commit 4a4afd5 into huggingface:main Jan 14, 2025
12 checks passed
DN6 pushed a commit that referenced this pull request Jan 15, 2025
Development

Successfully merging this pull request may close these issues.

Hunyuan Video Batch Size > 1 is broken again
5 participants