Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoRA PEFT] fix LoRA loading so that correct alphas are parsed #6225

Merged
merged 9 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/dreambooth/train_dreambooth_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,7 @@ def main(args):
# now we will add new LoRA weights to the attention layers
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
Expand All @@ -835,7 +836,10 @@ def main(args):
# The text encoder comes from 🤗 transformers, we will also attach adapters to it.
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder.add_adapter(text_lora_config)

Expand Down
10 changes: 8 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,15 +978,21 @@ def main(args):

# now we will add new LoRA weights to the attention layers
unet_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)

# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
Expand Down
5 changes: 4 additions & 1 deletion examples/text_to_image/train_text_to_image_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,10 @@ def main():
param.requires_grad_(False)

unet_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

# Move unet, vae and text_encoder to device and cast to weight_dtype
Expand Down
10 changes: 8 additions & 2 deletions examples/text_to_image/train_text_to_image_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,10 @@ def main(args):
# now we will add new LoRA weights to the attention layers
# Set correct lora layers
unet_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

unet.add_adapter(unet_lora_config)
Expand All @@ -635,7 +638,10 @@ def main(args):
if args.train_text_encoder:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
text_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
Expand Down
63 changes: 60 additions & 3 deletions tests/lora/test_lora_layers_peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ class PeftLoraLoaderMixinTests:
unet_kwargs = None
vae_kwargs = None

def get_dummy_components(self, scheduler_cls=None):
def get_dummy_components(self, scheduler_cls=None, lora_alpha=None):
scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
lora_alpha = 4 if lora_alpha is None else lora_alpha

torch.manual_seed(0)
unet = UNet2DConditionModel(**self.unet_kwargs)
Expand All @@ -123,11 +124,14 @@ def get_dummy_components(self, scheduler_cls=None):
tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")

text_lora_config = LoraConfig(
r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
r=4,
lora_alpha=lora_alpha,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
init_lora_weights=False,
)

unet_lora_config = LoraConfig(
r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
r=4, lora_alpha=lora_alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)

unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
Expand Down Expand Up @@ -714,6 +718,59 @@ def test_simple_inference_with_text_unet_lora_unloaded(self):
"Fused lora should change the output",
)

def test_if_lora_alpha_is_correctly_parsed(self):
lora_alpha = 8

components, _, text_lora_config, unet_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

pipe.unet.add_adapter(unet_lora_config)
pipe.text_encoder.add_adapter(text_lora_config)
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)

# Inference works?
_ = pipe(**inputs, generator=torch.manual_seed(0)).images

with tempfile.TemporaryDirectory() as tmpdirname:
unet_state_dict = get_peft_model_state_dict(pipe.unet)
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)

if self.has_two_text_encoders:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)

self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=unet_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
)
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=unet_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
)
loaded_pipe = self.pipeline_class(**components)
loaded_pipe.load_lora_weights(tmpdirname)

# Inference works?
_ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images

assert (
loaded_pipe.unet.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for UNet."
assert (
loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for text encoder."
if self.has_two_text_encoders:
assert (
loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for text encoder 2."

def test_simple_inference_with_text_unet_lora_unfused(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
Expand Down
Loading