Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for HunYuanDit ControlNet #4245

Merged
merged 8 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 108 additions & 1 deletion comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import comfy.t2i_adapter.adapter
import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit

import comfy.ldm.hydit.controlnet

def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
Expand Down Expand Up @@ -370,9 +370,116 @@ def load_controlnet_mmdit(sd):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control

class ControlNetWarperHunyuanDiT(ControlNet):
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
if control_prev is not None:
return control_prev
else:
return None

dtype = self.control_model.dtype
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype

output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
compression_ratio = self.compression_ratio
if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

def get_tensor(name):
if name in cond:
if isinstance(cond[name], torch.Tensor):
return cond[name].to(dtype)
else:
return cond[name]
else:
return None

encoder_hidden_states = get_tensor('c_crossattn')
text_embedding_mask = get_tensor('text_embedding_mask')
encoder_hidden_states_t5 = get_tensor('encoder_hidden_states_t5')
text_embedding_mask_t5 = get_tensor('text_embedding_mask_t5')
image_meta_size = get_tensor('image_meta_size')
style = get_tensor('style')
cos_cis_img = get_tensor('cos_cis_img')
sin_cis_img = get_tensor('sin_cis_img')

timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

control = self.control_model(
x=x_noisy.to(dtype),
t=timestep.float(),
condition=self.cond_hint,
encoder_hidden_states=encoder_hidden_states,
text_embedding_mask=text_embedding_mask,
encoder_hidden_states_t5=encoder_hidden_states_t5,
text_embedding_mask_t5=text_embedding_mask_t5,
image_meta_size=image_meta_size,
style=style,
cos_cis_img=cos_cis_img,
sin_cis_img=sin_cis_img,
**self.extra_args
)
return self.control_merge(control, control_prev, output_dtype)

def copy(self):
c = ControlNetWarperHunyuanDiT(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c

def load_controlnet_hunyuandit(controlnet_data):

supported_inference_dtypes = [torch.float16, torch.float32]

unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init

control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
missing, unexpected = control_model.load_state_dict(controlnet_data)

if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))

if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))

latent_format = comfy.latent_formats.SDXL()
control = ControlNetWarperHunyuanDiT(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control

def load_controlnet(ckpt_path, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data)

if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)

Expand Down
Loading
Loading