[LoRA] Improve warning messages when LoRA loading becomes a no-op #10187
base: main
Conversation
```python
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
if not only_text_encoder:
    # Load the layers corresponding to UNet.
    logger.info(f"Loading {cls.unet_name}.")
    unet.load_lora_adapter(
        state_dict,
        prefix=cls.unet_name,
        network_alphas=network_alphas,
        adapter_name=adapter_name,
        _pipeline=_pipeline,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )
```
We're handling all of this within the `load_lora_adapter()` method now, which I think is more appropriate because it takes care of the logging both when:

- users try to load a LoRA into a model via `pipe.load_lora_weights()`, and
- users try to load a LoRA directly into a model with the `load_lora_adapter()` method (with something like `unet.load_lora_adapter()`).

This also helps avoid duplication. I have run the integration tests, too, and nothing breaks because of this change.
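A minimal sketch of that idea (not the actual diffusers implementation; the function signature and the warning text are illustrative):

```python
import logging

logger = logging.getLogger(__name__)

def load_lora_adapter(state_dict, prefix=None, **kwargs):
    # Filter the incoming state dict by `prefix` inside the method itself, so both
    # the pipeline-level and the model-level entry points get the warning from one place.
    if prefix is not None:
        matching = {k: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
        if not matching:
            # Loading would silently do nothing for this model; surface that to the user.
            logger.warning(
                f"No LoRA keys associated with the prefix '{prefix}' were found in the "
                "provided state dict, so loading into this model would be a no-op."
            )
            return
        state_dict = {k[len(prefix) + 1 :]: v for k, v in matching.items()}
    # ... proceed with injecting the LoRA layers via PEFT ...
```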
```python
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
    self.load_lora_into_text_encoder(
        text_encoder_state_dict,
        network_alphas=network_alphas,
        text_encoder=self.text_encoder,
        prefix="text_encoder",
        lora_scale=self.lora_scale,
        adapter_name=adapter_name,
        _pipeline=self,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )
```
Similar philosophy as explained above.
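A minimal, self-contained sketch of the same pattern for the text encoder path (again illustrative only, not the diffusers implementation):

```python
import logging

logger = logging.getLogger(__name__)

def load_lora_into_text_encoder(state_dict, text_encoder, prefix="text_encoder", **kwargs):
    # The emptiness check (and the user-facing warning) lives inside the helper,
    # so callers no longer need to guard with `if len(text_encoder_state_dict) > 0:`.
    te_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
    if not te_state_dict:
        logger.warning(
            f"No LoRA keys with the prefix '{prefix}' were found in the state dict; "
            "loading into the text encoder would be a no-op."
        )
        return
    # ... inject the LoRA layers into `text_encoder` ...
```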
```python
logger.info(f"Loading {cls.unet_name}.")
unet.load_lora_adapter(
    state_dict,
    prefix=cls.unet_name,
    network_alphas=network_alphas,
    adapter_name=adapter_name,
    _pipeline=_pipeline,
    low_cpu_mem_usage=low_cpu_mem_usage,
)
```
In case `prefix` is not None and no `prefix`-matched state dict keys are found, we log from the `load_lora_adapter()` method. This way, we cover both `load_lora_weights()`, which is pipeline-level, and `load_lora_adapter()`, which is model-level.
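As a hedged usage illustration (the base model id is a real checkpoint, but the LoRA repo name is a placeholder for a checkpoint that contains no `unet.`-prefixed keys):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Pipeline-level entry point: with no "unet." keys in the checkpoint, loading into
# the UNet is a no-op, and the warning is emitted from load_lora_adapter().
pipe.load_lora_weights("some-user/text-encoder-only-lora")

# Model-level entry point: loading directly into the model hits the same code path,
# and therefore produces the same warning.
pipe.unet.load_lora_adapter("some-user/text-encoder-only-lora", prefix="unet")
```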
@DN6 a gentle ping.
What does this PR do?
As discussed with Dhruv, this PR improves our logging for LoRAs when the specified state dict isn't an appropriate one, leaving the `load_lora_weights()`, `load_lora_into_text_encoder()`, and `load_lora_adapter()` methods essentially a no-op. #9950 mentions this problem, and I think it's better to at least let the user know that the specified LoRA state dict is either not a correct one (which we already do) or is an ineffective one (this PR).
This PR also standardizes how we log this kind of information across our LoRA loading utilities.
Some comments are in-line.
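For debugging on the user side, a small hypothetical helper (not part of diffusers) can show which prefixes a LoRA state dict actually contains, which is usually enough to explain why a load turned into a no-op:

```python
from collections import Counter

def summarize_lora_prefixes(state_dict):
    """Count state dict keys by their top-level prefix (e.g. 'unet', 'text_encoder')."""
    return Counter(key.split(".", 1)[0] for key in state_dict)

# Example with a text-encoder-only state dict: there are no 'unet' keys,
# so loading it into a UNet would be a no-op.
dummy = {
    "text_encoder.layers.0.lora_A.weight": None,
    "text_encoder.layers.0.lora_B.weight": None,
}
print(summarize_lora_prefixes(dummy))  # Counter({'text_encoder': 2})
```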