Skip to content

Commit

Permalink
feat: add support for HunYuanDit ControlNet (#4245)
Browse files Browse the repository at this point in the history
* add support for HunYuanDit ControlNet

* fix hunyuandit controlnet

* fix typo in hunyuandit controlnet

* fix typo in hunyuandit controlnet

* fix code format style

* add control_weight support for HunyuanDit Controlnet

* use control_weights in HunyuanDit Controlnet

* fix typo
  • Loading branch information
CrazyBoyM authored Aug 9, 2024
1 parent 4133226 commit 06eb9fb
Show file tree
Hide file tree
Showing 4 changed files with 512 additions and 1 deletion.
109 changes: 108 additions & 1 deletion comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import comfy.t2i_adapter.adapter
import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit

import comfy.ldm.hydit.controlnet

def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
Expand Down Expand Up @@ -382,9 +382,116 @@ def load_controlnet_mmdit(sd):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control

class ControlNetWarperHunyuanDiT(ControlNet):
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
if control_prev is not None:
return control_prev
else:
return None

dtype = self.control_model.dtype
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype

output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
compression_ratio = self.compression_ratio
if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

def get_tensor(name):
if name in cond:
if isinstance(cond[name], torch.Tensor):
return cond[name].to(dtype)
else:
return cond[name]
else:
return None

encoder_hidden_states = get_tensor('c_crossattn')
text_embedding_mask = get_tensor('text_embedding_mask')
encoder_hidden_states_t5 = get_tensor('encoder_hidden_states_t5')
text_embedding_mask_t5 = get_tensor('text_embedding_mask_t5')
image_meta_size = get_tensor('image_meta_size')
style = get_tensor('style')
cos_cis_img = get_tensor('cos_cis_img')
sin_cis_img = get_tensor('sin_cis_img')

timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

control = self.control_model(
x=x_noisy.to(dtype),
t=timestep.float(),
condition=self.cond_hint,
encoder_hidden_states=encoder_hidden_states,
text_embedding_mask=text_embedding_mask,
encoder_hidden_states_t5=encoder_hidden_states_t5,
text_embedding_mask_t5=text_embedding_mask_t5,
image_meta_size=image_meta_size,
style=style,
cos_cis_img=cos_cis_img,
sin_cis_img=sin_cis_img,
**self.extra_args
)
return self.control_merge(control, control_prev, output_dtype)

def copy(self):
c = ControlNetWarperHunyuanDiT(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c

def load_controlnet_hunyuandit(controlnet_data):

supported_inference_dtypes = [torch.float16, torch.float32]

unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init

control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
missing, unexpected = control_model.load_state_dict(controlnet_data)

if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))

if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))

latent_format = comfy.latent_formats.SDXL()
control = ControlNetWarperHunyuanDiT(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control

def load_controlnet(ckpt_path, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data)

if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)

Expand Down
Loading

0 comments on commit 06eb9fb

Please sign in to comment.