From acc152b674fd1c983acc6efd8aedbeb380660c0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 19 Feb 2025 00:06:54 +0200 Subject: [PATCH] Support loading and using SkyReels-V1-Hunyuan-I2V (#6862) * Support SkyReels-V1-Hunyuan-I2V * VAE scaling * Fix T2V oops * Proper latent scaling --- comfy/ldm/hunyuan_video/model.py | 2 +- comfy/model_base.py | 9 +++++++++ comfy/model_detection.py | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index fc3a6744413..f3f44584385 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -310,7 +310,7 @@ def block_wrap(args): shape[i] = shape[i] // self.patch_size[i] img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size) img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) - img = img.reshape(initial_shape) + img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) return img def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs): diff --git a/comfy/model_base.py b/comfy/model_base.py index 98f462b327a..0eeaed790df 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -871,6 +871,15 @@ def extra_conds(self, **kwargs): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + image = kwargs.get("concat_latent_image", None) + noise = kwargs.get("noise", None) + + if image is not None: + padding_shape = (noise.shape[0], 16, noise.shape[2] - 1, noise.shape[3], noise.shape[4]) + latent_padding = torch.zeros(padding_shape, device=noise.device, dtype=noise.dtype) + image_latents = torch.cat([image.to(noise), latent_padding], dim=2) + out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_latents)) + guidance = kwargs.get("guidance", 6.0) if guidance is not None: out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 2644dd0dc23..5051f821d4f 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -136,7 +136,7 @@ def detect_unet_config(state_dict, key_prefix): if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video dit_config = {} dit_config["image_model"] = "hunyuan_video" - dit_config["in_channels"] = 16 + dit_config["in_channels"] = state_dict["img_in.proj.weight"].shape[1] #SkyReels img2video has 32 input channels dit_config["patch_size"] = [1, 2, 2] dit_config["out_channels"] = 16 dit_config["vec_in_dim"] = 768