Merge branch 'main' into RMS

leisuzz authored Jan 14, 2025
2 parents 701f49f + 74b6752 commit 4f98ac0
Showing 28 changed files with 433 additions and 92 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/api/pipelines/flux.md
@@ -367,7 +367,7 @@ transformer_8bit = FluxTransformer2DModel.from_pretrained(

pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder=text_encoder_8bit,
text_encoder_2=text_encoder_8bit,
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
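The fix matters because the 8-bit encoder built in the collapsed lines above is the T5 model stored under the `text_encoder_2` subfolder of FLUX.1-dev; passing it as `text_encoder` (the CLIP component) attached it to the wrong slot. A minimal sketch of that collapsed setup, assuming the standard bitsandbytes flow used by the surrounding docs:

```python
import torch
from transformers import BitsAndBytesConfig, T5EncoderModel

# Quantize the second text encoder (T5) to 8-bit. It lives under
# "text_encoder_2", which is why the pipeline kwarg must be `text_encoder_2`.
quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
```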
8 changes: 4 additions & 4 deletions docs/source/en/api/pipelines/hunyuan_video.md
@@ -16,7 +16,7 @@

[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent.

*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).*
*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/tencent/HunyuanVideo).*

<Tip>

@@ -45,14 +45,14 @@ from diffusers.utils import export_to_video

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
"tencent/HunyuanVideo",
"hunyuanvideo-community/HunyuanVideo",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.float16,
torch_dtype=torch.bfloat16,
)

pipeline = HunyuanVideoPipeline.from_pretrained(
"tencent/HunyuanVideo",
"hunyuanvideo-community/HunyuanVideo",
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
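For completeness, a hedged sketch of how the quantized pipeline above is typically driven; the prompt and frame settings are illustrative, following values used elsewhere in these docs:

```python
prompt = "A cat walks on the grass, realistic style."
video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
export_to_video(video, "output.mp4", fps=15)
```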
4 changes: 2 additions & 2 deletions docs/source/en/using-diffusers/text-img2vid.md
@@ -78,10 +78,10 @@ from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
"tencent/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
"hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
"tencent/HunyuanVideo", transformer=transformer, torch_dtype=torch.float16
"hunyuanvideo-community/HunyuanVideo", transformer=transformer, torch_dtype=torch.float16
)

# reduce memory requirements
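The hunk ends at the `# reduce memory requirements` comment, with the remaining lines collapsed. A plausible continuation, assuming the doc applies the usual memory savers for this pipeline (both methods do exist on the pipeline and its VAE):

```python
pipe.vae.enable_tiling()         # decode the video in spatial tiles
pipe.enable_model_cpu_offload()  # keep only the active component on the GPU

video = pipe(prompt="A cat walks on the grass, realistic style.", num_frames=61).frames[0]
export_to_video(video, "output.mp4", fps=15)
```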
@@ -765,7 +765,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)

transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
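The one-word fix above is easy to miss: `StableDiffusion3Pipeline.lora_state_dict` keys transformer weights under a `transformer.` prefix, so the old `startswith("unet.")` test matched nothing and the hook restored an empty state dict. A self-contained sketch with hypothetical keys:

```python
import torch

# Hypothetical entries mimicking an SD3 LoRA checkpoint.
lora_state_dict = {
    "transformer.transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 64),
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 64),
}

# Corrected filter: keep transformer entries and strip their prefix.
transformer_state_dict = {
    k.replace("transformer.", ""): v
    for k, v in lora_state_dict.items()
    if k.startswith("transformer.")
}
assert list(transformer_state_dict) == ["transformer_blocks.0.attn.to_q.lora_A.weight"]
```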
2 changes: 2 additions & 0 deletions setup.py
@@ -135,6 +135,7 @@
"transformers>=4.41.2",
"urllib3<=2.0.0",
"black",
"phonemizer",
]

# this is a lookup table with items like:
@@ -227,6 +228,7 @@ def run(self):
"scipy",
"torchvision",
"transformers",
"phonemizer",
)
extras["torch"] = deps_list("torch", "accelerate")

1 change: 1 addition & 0 deletions src/diffusers/dependency_versions_table.py
@@ -43,4 +43,5 @@
"transformers": "transformers>=4.41.2",
"urllib3": "urllib3<=2.0.0",
"black": "black",
"phonemizer": "phonemizer",
}
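The two hunks above feed the same mechanism: the lookup table maps a short name to a pinned requirement, and `deps_list` pulls entries by name when building extras. A sketch of that flow; the `deps_list` body is assumed from the visible `deps_list("torch", "accelerate")` call:

```python
deps = {
    "transformers": "transformers>=4.41.2",
    "urllib3": "urllib3<=2.0.0",
    "black": "black",
    "phonemizer": "phonemizer",
}

def deps_list(*pkgs: str) -> list[str]:
    # Resolve short names to full version specifiers.
    return [deps[pkg] for pkg in pkgs]

print(deps_list("phonemizer"))  # ['phonemizer']
```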
20 changes: 11 additions & 9 deletions src/diffusers/loaders/peft.py
@@ -300,15 +300,17 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
try:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
except RuntimeError as e:
for module in self.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
for active_adapter in active_adapters:
if adapter_name in active_adapter:
module.delete_adapter(adapter_name)

self.peft_config.pop(adapter_name)
except Exception as e:
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
if hasattr(self, "peft_config"):
for module in self.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
for active_adapter in active_adapters:
if adapter_name in active_adapter:
module.delete_adapter(adapter_name)

self.peft_config.pop(adapter_name)
logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
raise

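A sketch of what the broader cleanup buys the caller: when injection fails before `peft_config` is ever attached, the handler no longer raises a second error inside the `except` block. The model, adapter name, and failing state dict below are hypothetical:

```python
try:
    # `bad_state_dict` is assumed to be a malformed LoRA dict for `transformer`.
    transformer.load_lora_adapter(bad_state_dict, adapter_name="broken")
except Exception:
    # After the fix, the model is left clean: no half-injected adapter layers
    # and no stale "broken" entry, even if `peft_config` was never created.
    assert "broken" not in getattr(transformer, "peft_config", {})
```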
11 changes: 8 additions & 3 deletions src/diffusers/loaders/single_file_utils.py
@@ -186,6 +186,7 @@
"inpainting": 512,
"inpainting_v2": 512,
"controlnet": 512,
"instruct-pix2pix": 512,
"v2": 768,
"v1": 512,
}
@@ -605,10 +606,14 @@ def infer_diffusers_model_type(checkpoint):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
):
if checkpoint["img_in.weight"].shape[1] == 384:
model_type = "flux-fill"
if "model.diffusion_model.img_in.weight" in checkpoint:
key = "model.diffusion_model.img_in.weight"
else:
key = "img_in.weight"

elif checkpoint["img_in.weight"].shape[1] == 128:
if checkpoint[key].shape[1] == 384:
model_type = "flux-fill"
elif checkpoint[key].shape[1] == 128:
model_type = "flux-depth"
else:
model_type = "flux-dev"
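The change lets the shape probe work for both single-file layouts, where weights are stored either bare or under a `model.diffusion_model.` prefix. A self-contained sketch of the same logic; the out-features value in the fabricated tensor is illustrative, only `shape[1]` matters:

```python
import torch

# Fabricated checkpoint fragment using the prefixed layout.
checkpoint = {"model.diffusion_model.img_in.weight": torch.zeros(3072, 384)}

key = (
    "model.diffusion_model.img_in.weight"
    if "model.diffusion_model.img_in.weight" in checkpoint
    else "img_in.weight"
)
if checkpoint[key].shape[1] == 384:
    model_type = "flux-fill"
elif checkpoint[key].shape[1] == 128:
    model_type = "flux-depth"
else:
    model_type = "flux-dev"
print(model_type)  # flux-fill
```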
137 changes: 96 additions & 41 deletions src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
@@ -1010,10 +1010,12 @@ def __init__(
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 512
self.tile_sample_min_width = 512
self.tile_sample_min_num_frames = 16

# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 448
self.tile_sample_stride_width = 448
self.tile_sample_stride_num_frames = 8

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
@@ -1023,8 +1025,10 @@ def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_min_num_frames: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
tile_sample_stride_num_frames: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1046,8 +1050,10 @@
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

def disable_tiling(self) -> None:
r"""
@@ -1073,18 +1079,13 @@ def disable_slicing(self) -> None:
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape

if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
return self._temporal_tiled_encode(x)

if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)

if self.use_framewise_encoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
enc = self.encoder(x)
enc = self.encoder(x)

return enc

@@ -1121,19 +1122,15 @@ def _decode(
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio

if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
return self._temporal_tiled_decode(z, temb, return_dict=return_dict)

if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, temb, return_dict=return_dict)

if self.use_framewise_decoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
dec = self.decoder(z, temb)
dec = self.decoder(z, temb)

if not return_dict:
return (dec,)
@@ -1189,6 +1186,14 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.
)
return b

def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
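# Linear crossfade along the frame axis (dim 2): the last `blend_extent`
# frames of tile `a` fade out while the first `blend_extent` frames of
# tile `b` fade in, hiding the seam between temporal tiles.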
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
x / blend_extent
)
return b

def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
@@ -1217,17 +1222,9 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:

for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
if self.use_framewise_encoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
time = self.encoder(
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
)
time = self.encoder(
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
)

row.append(time)
rows.append(row)
@@ -1283,17 +1280,7 @@ def tiled_decode(
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
if self.use_framewise_decoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
time = self.decoder(
z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
)
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)

row.append(time)
rows.append(row)
@@ -1318,6 +1305,74 @@ def tiled_decode(

return DecoderOutput(sample=dec)

def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
batch_size, num_channels, num_frames, height, width = x.shape
latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
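# Causal VAE frame arithmetic: the first input frame maps to one latent
# frame on its own; each following group of `temporal_compression_ratio`
# input frames maps to one more.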

tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames

row = []
for i in range(0, num_frames, self.tile_sample_stride_num_frames):
tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
tile = self.tiled_encode(tile)
else:
tile = self.encoder(tile)
if i > 0:
tile = tile[:, :, 1:, :, :]
row.append(tile)

result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
else:
result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])

enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
return enc

def _temporal_tiled_decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1

tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames

row = []
for i in range(0, num_frames, tile_latent_stride_num_frames):
tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
decoded = self.tiled_decode(tile, temb, return_dict=True).sample
else:
decoded = self.decoder(tile, temb)
if i > 0:
decoded = decoded[:, :, :-1, :, :]
row.append(decoded)

result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
result_row.append(tile)
else:
result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])

dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]

if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

def forward(
self,
sample: torch.Tensor,
@@ -1334,5 +1389,5 @@ def forward(
z = posterior.mode()
dec = self.decode(z, temb)
if not return_dict:
return (dec,)
return (dec.sample,)
return dec
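Taken together, the new frame knobs are driven through `enable_tiling`. A usage sketch; the repository id and VAE subfolder follow current diffusers conventions for LTX-Video but should be treated as assumptions:

```python
import torch
from diffusers import AutoencoderKLLTXVideo

vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.bfloat16
)
# Strides smaller than the min sizes give overlapping tiles whose seams the
# blend_t/blend_h/blend_v helpers smooth over.
vae.enable_tiling(
    tile_sample_min_num_frames=16,
    tile_sample_stride_num_frames=8,
)
```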