diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index b176a685c963..99a05d54832f 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -923,11 +923,28 @@ def load_model_hook(models, input_dir):
                     transformer_ = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
-
         else:
             transformer_ = FluxTransformer2DModel.from_pretrained(
                 args.pretrained_model_name_or_path, subfolder="transformer"
             ).to(accelerator.device, weight_dtype)
+
+            # Handle input dimension doubling before adding adapter
+            with torch.no_grad():
+                initial_input_channels = transformer_.config.in_channels
+                new_linear = torch.nn.Linear(
+                    transformer_.x_embedder.in_features * 2,
+                    transformer_.x_embedder.out_features,
+                    bias=transformer_.x_embedder.bias is not None,
+                    dtype=transformer_.dtype,
+                    device=transformer_.device,
+                )
+                new_linear.weight.zero_()
+                new_linear.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.weight)
+                if transformer_.x_embedder.bias is not None:
+                    new_linear.bias.copy_(transformer_.x_embedder.bias)
+                transformer_.x_embedder = new_linear
+                transformer_.register_to_config(in_channels=initial_input_channels * 2)
+            transformer_.add_adapter(transformer_lora_config)
 
         lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
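
The added block matters because the control LoRA checkpoint was saved against a transformer whose `x_embedder` had already been widened to accept the concatenated control latents; a freshly loaded base transformer must get the same expansion (and the adapter) before `FluxControlPipeline.lora_state_dict(input_dir)` can be applied without shape mismatches. The sketch below is not the training script itself; it illustrates the same channel-doubling trick on a plain `nn.Linear` with arbitrary stand-in sizes, and checks that zero-initializing the new columns leaves the original mapping unchanged when the extra channels are zero.

```python
# Minimal sketch, assuming a generic nn.Linear in place of the Flux x_embedder;
# the 64/128 sizes are illustrative only.
import torch

old_linear = torch.nn.Linear(64, 128)  # stand-in for transformer.x_embedder

with torch.no_grad():
    # New projection accepts twice as many input channels.
    new_linear = torch.nn.Linear(
        old_linear.in_features * 2,
        old_linear.out_features,
        bias=old_linear.bias is not None,
        dtype=old_linear.weight.dtype,
    )
    # Zero the whole weight, then copy the pretrained columns into the first half,
    # so the extra (control) channels start with no influence on the output.
    new_linear.weight.zero_()
    new_linear.weight[:, : old_linear.in_features].copy_(old_linear.weight)
    if old_linear.bias is not None:
        new_linear.bias.copy_(old_linear.bias)

# Sanity check: with the control half zeroed, the expanded layer reproduces
# the original layer's output exactly.
x = torch.randn(2, 64)
x_doubled = torch.cat([x, torch.zeros_like(x)], dim=-1)
assert torch.allclose(old_linear(x), new_linear(x_doubled), atol=1e-6)
print("expanded layer matches the original on the first half of channels")
```

In the hook itself the same idea is applied to `transformer_.x_embedder`, followed by `register_to_config(in_channels=...)` so the model config reflects the doubled width, and only then is the LoRA adapter attached and the saved state dict loaded.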