[core] Pyramid Attention Broadcast #9562

Merged · 80 commits · Jan 27, 2025

Changes shown are from 11 of the 80 commits.

Commits:
`67c729d` start pyramid attention broadcast (a-r-r-o-w, Oct 1, 2024)
`6d3bdb5` add coauthor (a-r-r-o-w, Oct 3, 2024)
`3737101` update (a-r-r-o-w, Oct 3, 2024)
`d5c738d` make style (a-r-r-o-w, Oct 3, 2024)
`1c97e04` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Oct 3, 2024)
`ae4abb1` update (a-r-r-o-w, Oct 3, 2024)
`955e4f7` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Oct 3, 2024)
`9f6987f` make style (a-r-r-o-w, Oct 3, 2024)
`b3547c6` add docs (a-r-r-o-w, Oct 4, 2024)
`afd0c17` add tests (a-r-r-o-w, Oct 4, 2024)
`6265b65` update (a-r-r-o-w, Oct 5, 2024)
`6816fe1` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Oct 15, 2024)
`9cb4e87` Update docs/source/en/api/pipelines/cogvideox.md (a-r-r-o-w, Oct 15, 2024)
`6b1f55e` Update docs/source/en/api/pipelines/cogvideox.md (a-r-r-o-w, Oct 15, 2024)
`37d2366` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Oct 30, 2024)
`a5f51bb` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Oct 31, 2024)
`18b7d6d` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Nov 5, 2024)
`c52cf42` Pyramid Attention Broadcast rewrite + introduce hooks (#9826) (a-r-r-o-w, Nov 8, 2024)
`3de2c18` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Nov 9, 2024)
`d95d61a` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Dec 6, 2024)
`6090575` merge pyramid-attention-rewrite-2 (a-r-r-o-w, Dec 9, 2024)
`af51f5d` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Dec 9, 2024)
`903514f` make style (a-r-r-o-w, Dec 9, 2024)
`b690db2` remove changes from latte transformer (a-r-r-o-w, Dec 9, 2024)
`63ab886` revert docs changes (a-r-r-o-w, Dec 9, 2024)
`d40bced` better debug message (a-r-r-o-w, Dec 9, 2024)
`0ea904e` add todos for future (a-r-r-o-w, Dec 9, 2024)
`9d452dc` update tests (a-r-r-o-w, Dec 9, 2024)
`cfe3921` make style (a-r-r-o-w, Dec 9, 2024)
`b972c4b` cleanup (a-r-r-o-w, Dec 9, 2024)
`2b558ff` fix (a-r-r-o-w, Dec 9, 2024)
`0b2629d` improve log message; fix latte test (a-r-r-o-w, Dec 9, 2024)
`9182f57` refactor (a-r-r-o-w, Dec 9, 2024)
`d974401` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Dec 30, 2024)
`62b5b8d` update (a-r-r-o-w, Dec 30, 2024)
`bb250d6` update (a-r-r-o-w, Dec 30, 2024)
`cbc086f` update (a-r-r-o-w, Dec 30, 2024)
`7debcec` revert changes to tests (a-r-r-o-w, Dec 30, 2024)
`a5c34af` update docs (a-r-r-o-w, Dec 30, 2024)
`ad24269` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Dec 30, 2024)
`bbcde6b` update tests (a-r-r-o-w, Dec 30, 2024)
`b148ab4` Apply suggestions from code review (a-r-r-o-w, Dec 31, 2024)
`d4ecd6c` update (a-r-r-o-w, Dec 31, 2024)
`6cca58f` fix flux test (a-r-r-o-w, Dec 31, 2024)
`c2e0e3b` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Dec 31, 2024)
`d9fad00` reorder (a-r-r-o-w, Jan 2, 2025)
`35296eb` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Jan 13, 2025)
`2436b3f` refactor (a-r-r-o-w, Jan 13, 2025)
`95c8148` make fix-copies (a-r-r-o-w, Jan 13, 2025)
`76afc6a` update docs (a-r-r-o-w, Jan 13, 2025)
`fb66167` fixes (a-r-r-o-w, Jan 13, 2025)
`1040c91` more fixes (a-r-r-o-w, Jan 13, 2025)
`ffbabb5` make style (a-r-r-o-w, Jan 13, 2025)
`1b92b1d` update tests (a-r-r-o-w, Jan 13, 2025)
`88d917d` update code example (a-r-r-o-w, Jan 13, 2025)
`e4d8b12` make fix-copies (a-r-r-o-w, Jan 13, 2025)
`cc94647` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Jan 13, 2025)
`071a0ba` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Jan 15, 2025)
`ae8bd99` refactor based on reviews (a-r-r-o-w, Jan 15, 2025)
`a9ee5a4` use maybe_free_model_hooks (a-r-r-o-w, Jan 15, 2025)
`1a59688` CacheMixin (a-r-r-o-w, Jan 15, 2025)
`c8616a6` make style (a-r-r-o-w, Jan 15, 2025)
`08a209d` update (a-r-r-o-w, Jan 15, 2025)
`15e645d` add current_timestep property; update docs (a-r-r-o-w, Jan 15, 2025)
`d6ce4ab` make fix-copies (a-r-r-o-w, Jan 15, 2025)
`96fce86` update (a-r-r-o-w, Jan 15, 2025)
`107e375` improve tests (a-r-r-o-w, Jan 15, 2025)
`f7d7e38` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Jan 15, 2025)
`40fc7a5` try circular import fix (a-r-r-o-w, Jan 15, 2025)
`248f103` apply suggestions from review (a-r-r-o-w, Jan 15, 2025)
`0a290a6` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Jan 16, 2025)
`fe93975` address review comments (a-r-r-o-w, Jan 16, 2025)
`2b59994` Apply suggestions from code review (a-r-r-o-w, Jan 17, 2025)
`a8e460e` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Jan 22, 2025)
`8c74a7a` refactor hook implementation (a-r-r-o-w, Jan 22, 2025)
`3f3e26a` add test suite for hooks (a-r-r-o-w, Jan 22, 2025)
`83d221f` PAB Refactor (#10667) (a-r-r-o-w, Jan 27, 2025)
`847760e` Merge branch 'main' into pyramid-attention-broadcast (a-r-r-o-w, Jan 27, 2025)
`3d269ad` update (a-r-r-o-w, Jan 27, 2025)
`5535fd6` fix remove hook behaviour (a-r-r-o-w, Jan 27, 2025)
41 changes: 40 additions & 1 deletion docs/source/en/api/pipelines/cogvideox.md
@@ -15,7 +15,7 @@

# CogVideoX

[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://huggingface.co/papers/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.

The abstract from the paper is:

@@ -100,6 +100,45 @@ It is also worth noting that torchao quantization is fully compatible with [torc
- [torchao](https://gist.github.com/a-r-r-o-w/4d9732d17412888c885480c6521a9897)
- [quanto](https://gist.github.com/a-r-r-o-w/31be62828b00a9292821b85c1017effa)

### Pyramid Attention Broadcast

[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) by Xuanlei Zhao, Xiaolong Jin, Kai Wang, and Yang You.

Pyramid Attention Broadcast (PAB) speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and re-using cached attention states. It builds on the observation that attention states differ very little numerically between successive steps. The difference is most pronounced in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross attention blocks. Cross attention computations can therefore be skipped most often, followed by temporal and then spatial attention. By combining PAB with other techniques such as Sequence Parallelism and CFG Parallelism, the authors achieve near real-time video generation.
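At its core, the broadcast schedule reduces to a small per-step decision rule. A minimal sketch (the helper below is hypothetical, not the PR's actual implementation; `skip_range` and `timestep_range` mirror the arguments of `enable_pyramid_attention_broadcast`):

```python
def should_recompute_attention(
    step_index: int, timestep: int, skip_range: int, timestep_range: list
) -> bool:
    low, high = timestep_range
    # Outside the broadcast window, attention is always recomputed.
    if not (low < timestep < high):
        return True
    # Inside the window, recompute only every `skip_range`-th step; otherwise
    # the cached attention states from the last computed step are re-used
    # ("broadcast") as-is.
    return step_index % skip_range == 0
```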

PAB can be enabled on any pipeline by deriving from [`PyramidAttentionBroadcastMixin`] and keeping track of the current inference timestep in the pipeline. A minimal example of using PAB with CogVideoX:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.float16)
pipe.to("cuda")

pipe.enable_pyramid_attention_broadcast(
    spatial_attn_skip_range=2,
    spatial_attn_timestep_range=[100, 850],
)

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
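In this example, `spatial_attn_skip_range=2` means the spatial attention states are recomputed only on every second inference step and re-used from the cache otherwise, while `spatial_attn_timestep_range=[100, 850]` restricts the broadcast to that window of denoising timesteps, so the earliest and latest steps still compute attention normally.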

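The table below reports the benchmarks from the PR: `model_memory` is the memory needed to load the model, the `normal_*` columns are without PAB, and the `pab_*` columns are with PAB enabled (memory in GB, time in seconds):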
| model | model_memory | normal_memory | pab_memory | normal_time | pab_time | speedup |
|:----------:|:--------------:|:---------------:|:------------:|:-------------:|:----------:|:---------:|
| Cog-2b T2V | 12.55 | 35.342 | 35.342 | 86.915 | 63.914 | 1.359 |
| Cog-5b T2V | 19.66 | 40.945 | 40.945 | 246.152 | 168.59 | 1.460 |
| Cog-5b I2V | 19.764 | 42.74 | 42.74 | 246.867 | 170.111 | 1.451 |

## CogVideoXPipeline

[[autodoc]] CogVideoXPipeline
33 changes: 32 additions & 1 deletion docs/source/en/api/pipelines/latte.md
@@ -16,7 +16,7 @@

![latte text-to-video](https://github.com/Vchitect/Latte/blob/52bc0029899babbd6e9250384c83d8ed2670ff7a/visuals/latte.gif?raw=true)

[Latte: Latent Diffusion Transformer for Video Generation](https://arxiv.org/abs/2401.03048) from Monash University, Shanghai AI Lab, Nanjing University, and Nanyang Technological University.
[Latte: Latent Diffusion Transformer for Video Generation](https://huggingface.co/papers/2401.03048) from Monash University, Shanghai AI Lab, Nanjing University, and Nanyang Technological University.

The abstract from the paper is:

@@ -70,6 +70,37 @@ Without torch.compile(): Average inference time: 16.246 seconds.
With torch.compile(): Average inference time: 14.573 seconds.
```

### Pyramid Attention Broadcast

[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) by Xuanlei Zhao, Xiaolong Jin, Kai Wang, and Yang You.

Pyramid Attention Broadcast (PAB) speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and re-using cached attention states. It builds on the observation that attention states differ very little numerically between successive steps. The difference is most pronounced in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross attention blocks. Cross attention computations can therefore be skipped most often, followed by temporal and then spatial attention. By combining PAB with other techniques such as Sequence Parallelism and CFG Parallelism, the authors achieve near real-time video generation.

PAB can be enabled on any pipeline by deriving from [`PyramidAttentionBroadcastMixin`] and keeping track of the current inference timestep in the pipeline. A minimal example of using PAB with Latte:

```python
import torch
from diffusers import LattePipeline
from diffusers.utils import export_to_gif

pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
pipe.to("cuda")

pipe.enable_pyramid_attention_broadcast(
    spatial_attn_skip_range=2,
    cross_attn_skip_range=6,
    spatial_attn_timestep_range=[100, 800],
    cross_attn_timestep_range=[100, 800],
)

prompt = "A small cactus with a happy face in the Sahara desert."
videos = pipe(prompt).frames[0]
export_to_gif(videos, "latte.gif")
```
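Here cross attention is broadcast as well, with a more aggressive `cross_attn_skip_range=6`, consistent with the observation above that cross attention states change the least between successive steps.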

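Benchmarks for Latte, with the same column conventions as the CogVideoX table above (memory in GB, time in seconds):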
| model | model_memory | normal_memory | pab_memory | normal_time | pab_time | speedup |
|:----------:|:--------------:|:---------------:|:------------:|:-------------:|:----------:|:---------:|
| Latte | 11.007 | 25.594 | 25.594 | 28.026 | 24.073 | 1.164 |

## LattePipeline

[[autodoc]] LattePipeline
64 changes: 63 additions & 1 deletion src/diffusers/models/transformers/latte_transformer_3d.py
@@ -11,14 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from typing import Dict, Optional, Union

import torch
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..attention import BasicTransformerBlock
from ..attention_processor import AttentionProcessor
from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -165,6 +167,66 @@ def __init__(
    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
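These two accessors are the standard diffusers pattern, added here so PAB hooks can enumerate and swap the transformer's attention processors. A hypothetical usage sketch (assuming the `maxin-cn/Latte-1` checkpoint exposes the transformer in a `transformer` subfolder; `AttnProcessor2_0` is the stock SDPA-backed processor):

```python
import torch
from diffusers import LatteTransformer3DModel
from diffusers.models.attention_processor import AttnProcessor2_0

# Load only the transformer and inspect its attention processors,
# keyed by module path ("weight name").
model = LatteTransformer3DModel.from_pretrained(
    "maxin-cn/Latte-1", subfolder="transformer", torch_dtype=torch.float16
)
print(list(model.attn_processors.keys())[:4])

# Set a single processor instance for every Attention layer at once.
model.set_attn_processor(AttnProcessor2_0())
```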
7 changes: 6 additions & 1 deletion src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -29,6 +29,7 @@
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
from .pipeline_output import CogVideoXPipelineOutput


@@ -137,7 +138,7 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps


class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
    r"""
    Pipeline for text-to-video generation using CogVideoX.

@@ -605,6 +606,7 @@ def __call__(
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Default call parameters
@@ -674,6 +676,7 @@
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -729,6 +732,8 @@
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        self._current_timestep = None

        if not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
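Note the `_current_timestep` bookkeeping introduced in these hunks: the pipeline records the timestep at the top of each denoising iteration and clears it once the loop finishes, which is what lets the PAB hooks decide, per step, whether to recompute attention or re-use the cache. The same change is applied to the image-to-video and video-to-video pipelines below.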
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -33,6 +33,7 @@
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
from .pipeline_output import CogVideoXPipelineOutput


@@ -152,7 +153,7 @@ def retrieve_latents(
        raise AttributeError("Could not access latents of provided encoder_output")


class CogVideoXImageToVideoPipeline(DiffusionPipeline):
class CogVideoXImageToVideoPipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin):
    r"""
    Pipeline for image-to-video generation using CogVideoX.

@@ -679,6 +680,7 @@ def __call__(
            negative_prompt_embeds=negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._current_timestep = None
        self._interrupt = False

        # 2. Default call parameters
@@ -753,6 +755,7 @@
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -810,6 +813,8 @@
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        self._current_timestep = None

        if not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -30,6 +30,7 @@
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
from .pipeline_output import CogVideoXPipelineOutput


@@ -159,7 +160,7 @@ def retrieve_latents(
        raise AttributeError("Could not access latents of provided encoder_output")


class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
    r"""
    Pipeline for video-to-video generation using CogVideoX.

@@ -679,6 +680,7 @@ def __call__(
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Default call parameters
@@ -755,6 +757,7 @@
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -810,6 +813,8 @@
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        self._current_timestep = None

        if not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
13 changes: 9 additions & 4 deletions src/diffusers/pipelines/latte/pipeline_latte.py
@@ -37,6 +37,7 @@
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -132,7 +133,7 @@ class LattePipelineOutput(BaseOutput):
    frames: torch.Tensor


class LattePipeline(DiffusionPipeline):
class LattePipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin):
    r"""
    Pipeline for text-to-video generation using Latte.

@@ -623,7 +624,7 @@ def __call__(
        clean_caption: bool = True,
        mask_feature: bool = True,
        enable_temporal_attentions: bool = True,
        decode_chunk_size: Optional[int] = None,
        decode_chunk_size: int = 14,
    ) -> Union[LattePipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.
@@ -719,6 +720,7 @@ def __call__(
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._current_timestep = None
        self._interrupt = False

        # 2. Default height and width to transformer
@@ -780,6 +782,7 @@
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -836,8 +839,10 @@
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latents":
            video = self.decode_latents(latents, video_length, decode_chunk_size=14)
        self._current_timestep = None

        if not output_type == "latent":
            video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents
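Besides wiring in PAB, this last hunk fixes two small pre-existing issues in `LattePipeline`: the `output_type` check compared against `"latents"` instead of `"latent"`, and `decode_latents` ignored the user-supplied `decode_chunk_size` in favor of a hard-coded `14` (which is now the parameter's explicit default).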