[LoRA] Improve warning messages when LoRA loading becomes a no-op #10187

Open
wants to merge 29 commits into main

Conversation

@sayakpaul (Member) commented Dec 11, 2024

What does this PR do?

As discussed with Dhruv, this PR improves our logging for LoRAs when the specified state dict isn't an appropriate one, which leaves the load_lora_weights(), load_lora_into_text_encoder(), and load_lora_adapter() methods as essentially no-ops.

#9950 mentions this problem, and I think it's better to at least let the user know that the specified LoRA state dict is either not a correct one (which we already do) or is an ineffective one (this PR).

This PR also standardizes how we log this kind of information across our LoRA loading utilities.

Some comments are in-line.
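
For illustration, here is a minimal sketch of the scenario (the repo IDs are placeholders and the exact warning text this PR emits may differ):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

# If the checkpoint's keys carry none of the expected prefixes (e.g. "unet." or
# "text_encoder."), loading used to "succeed" silently while changing nothing.
# With this PR, the loaders log that no matching keys were found instead of
# staying quiet about the no-op.
pipe.load_lora_weights("some-user/lora-with-unexpected-key-format")
```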

Comment on lines 297 to 311
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
if not only_text_encoder:
    # Load the layers corresponding to UNet.
    logger.info(f"Loading {cls.unet_name}.")
    unet.load_lora_adapter(
        state_dict,
        prefix=cls.unet_name,
        network_alphas=network_alphas,
        adapter_name=adapter_name,
        _pipeline=_pipeline,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )
@sayakpaul (Member Author) commented

We're handling all of this within the load_lora_adapter() method now, which I think is more appropriate because:

  1. It takes care of logging when users try to load a LoRA into a model via pipe.load_lora_weights().
  2. It also covers users who load LoRAs directly into a model with the load_lora_adapter() method (with something like unet.load_lora_adapter()).

This helps avoid duplication. I have run the integration tests, too, and nothing breaks because of this change.
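
For reference, a sketch of the model-level path that motivates centralizing the logging there (the checkpoint ID and adapter name are placeholders, and this assumes a LoRA checkpoint whose keys use the "unet." prefix):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)

# Loading directly into the model goes through load_lora_adapter(), so the
# "no keys matched this prefix" logging lives in one place and also covers
# pipe.load_lora_weights(), which delegates to this method under the hood.
unet.load_lora_adapter("some-user/some-unet-lora", prefix="unet", adapter_name="style")
```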

Comment on lines -663 to -571
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
    self.load_lora_into_text_encoder(
        text_encoder_state_dict,
        network_alphas=network_alphas,
        text_encoder=self.text_encoder,
        prefix="text_encoder",
        lora_scale=self.lora_scale,
        adapter_name=adapter_name,
        _pipeline=self,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )
@sayakpaul (Member Author) commented

Similar philosophy to the one explained above.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul marked this pull request as ready for review December 11, 2024 08:42
@sayakpaul requested review from DN6 and yiyixuxu December 11, 2024 08:42
Comment on lines +300 to +308
logger.info(f"Loading {cls.unet_name}.")
unet.load_lora_adapter(
state_dict,
prefix=cls.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@sayakpaul (Member Author) commented

In case prefix is not None and no prefix-matched state dict keys are found, we log from the load_lora_adapter() method.

This way, we cover both load_lora_weights(), which is pipeline-level, and load_lora_adapter(), which is model-level.
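
A minimal sketch of the check being described, with hypothetical names (the actual load_lora_adapter() does much more than this):

```python
import logging

logger = logging.getLogger(__name__)

def load_lora_adapter_sketch(state_dict, prefix=None):
    """Illustrative only: surfaces the no-op case before any LoRA layers are created."""
    keys = list(state_dict.keys())
    if prefix is not None:
        model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
        if len(model_keys) == 0:
            # One warning here covers both entry points, since load_lora_weights()
            # funnels into this method.
            logger.warning(
                f"No LoRA keys associated with the prefix '{prefix}' were found; "
                "loading becomes a no-op."
            )
            return
        # Keep only the matching keys and strip the prefix before injecting layers.
        state_dict = {k[len(prefix) + 1 :]: v for k, v in state_dict.items() if k in model_keys}
    # ... proceed with creating and loading the LoRA layers from `state_dict` ...
```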

@sayakpaul (Member Author) commented

@DN6 a gentle ping.
