Commit
Add correct number of channels when resuming from checkpoint for Flux Control LoRA training (#10422)

* Add correct number of channels when resuming from checkpoint

* Fix Formatting
thesantatitan authored Jan 2, 2025
1 parent 91008aa commit 4b9f1c7
Showing 1 changed file with 18 additions and 1 deletion.
examples/flux-control/train_control_lora_flux.py (19 changes: 18 additions & 1 deletion)
@@ -923,11 +923,28 @@ def load_model_hook(models, input_dir):
                    transformer_ = model
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

        else:
            transformer_ = FluxTransformer2DModel.from_pretrained(
                args.pretrained_model_name_or_path, subfolder="transformer"
            ).to(accelerator.device, weight_dtype)

            # Handle input dimension doubling before adding adapter
            with torch.no_grad():
                initial_input_channels = transformer_.config.in_channels
                new_linear = torch.nn.Linear(
                    transformer_.x_embedder.in_features * 2,
                    transformer_.x_embedder.out_features,
                    bias=transformer_.x_embedder.bias is not None,
                    dtype=transformer_.dtype,
                    device=transformer_.device,
                )
                new_linear.weight.zero_()
                new_linear.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.weight)
                if transformer_.x_embedder.bias is not None:
                    new_linear.bias.copy_(transformer_.x_embedder.bias)
                transformer_.x_embedder = new_linear
                transformer_.register_to_config(in_channels=initial_input_channels * 2)

            transformer_.add_adapter(transformer_lora_config)

        lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
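For context, a minimal sketch of why the x_embedder is widened (an editor's illustration, not part of the commit; the shapes below are placeholders): Flux Control training concatenates the control latents with the noisy image latents along the channel dimension, so the transformer must accept twice its original in_channels, and a transformer re-instantiated from the base weights while resuming therefore needs the same zero-padded, doubled-width input projection before the LoRA adapter and state dict are loaded.

import torch

# Illustrative packed-latent shapes (batch, sequence, channels); the actual sizes
# depend on resolution and the VAE, so treat these numbers as assumptions.
noisy_latents = torch.randn(1, 4096, 64)    # original transformer in_channels
control_latents = torch.randn(1, 4096, 64)  # control image packed to the same shape

# Channel-wise concatenation is what doubles x_embedder.in_features, which is
# why the hook above rebuilds x_embedder as a zero-initialized, twice-as-wide Linear.
model_input = torch.cat([noisy_latents, control_latents], dim=-1)  # shape (1, 4096, 128)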
