From 8b3bef7b9d6ab4191892d061fad7732fdb93dde4 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Sun, 24 Dec 2023 09:59:41 +0530
Subject: [PATCH] [LoRA PEFT] fix LoRA loading so that correct alphas are
 parsed (#6225)

* initialize alpha too.

* add: test

* remove config parsing

* store rank

* debug

* remove faulty test
---
 examples/dreambooth/train_dreambooth_lora.py        |  6 +++++-
 examples/dreambooth/train_dreambooth_lora_sdxl.py   | 10 ++++++++--
 examples/text_to_image/train_text_to_image_lora.py  |  5 ++++-
 .../text_to_image/train_text_to_image_lora_sdxl.py  | 10 ++++++++--
 tests/lora/test_lora_layers_peft.py                 |  8 ++++++--
 5 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 55ef2bbeb8eb7..67132d6d88df8 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -827,6 +827,7 @@ def main(args):
     # now we will add new LoRA weights to the attention layers
     unet_lora_config = LoraConfig(
         r=args.rank,
+        lora_alpha=args.rank,
         init_lora_weights="gaussian",
         target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
     )
@@ -835,7 +836,10 @@ def main(args):
     # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
-            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
         )
         text_encoder.add_adapter(text_lora_config)
 
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 8a3ac294fef2c..0f41ad47d1acb 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -978,7 +978,10 @@ def main(args):
 
     # now we will add new LoRA weights to the attention layers
     unet_lora_config = LoraConfig(
-        r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
     unet.add_adapter(unet_lora_config)
 
@@ -986,7 +989,10 @@ def main(args):
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
-            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
         )
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index c8efbddd0b441..d6d0dee0883ce 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -452,7 +452,10 @@ def main():
         param.requires_grad_(False)
 
     unet_lora_config = LoraConfig(
-        r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
 
     # Move unet, vae and text_encoder to device and cast to weight_dtype
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index be17c13c28850..d95fcbbba0338 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -609,7 +609,10 @@ def main(args):
     # now we will add new LoRA weights to the attention layers
     # Set correct lora layers
     unet_lora_config = LoraConfig(
-        r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
     unet.add_adapter(unet_lora_config)
 
@@ -618,7 +621,10 @@
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
         text_lora_config = LoraConfig(
-            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
         )
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py
index f6cd2a714ae24..30125f64f6ac1 100644
--- a/tests/lora/test_lora_layers_peft.py
+++ b/tests/lora/test_lora_layers_peft.py
@@ -111,6 +111,7 @@ class PeftLoraLoaderMixinTests:
 
     def get_dummy_components(self, scheduler_cls=None):
         scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
+        rank = 4
 
         torch.manual_seed(0)
         unet = UNet2DConditionModel(**self.unet_kwargs)
@@ -125,11 +126,14 @@ def get_dummy_components(self, scheduler_cls=None):
         tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
 
         text_lora_config = LoraConfig(
-            r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
+            r=rank,
+            lora_alpha=rank,
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+            init_lora_weights=False,
         )
 
         unet_lora_config = LoraConfig(
-            r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
+            r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
         )
 
         unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
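
The common thread in the hunks above is that every LoraConfig in the training scripts now sets lora_alpha explicitly alongside r. In PEFT, the LoRA update is scaled by lora_alpha / r, and lora_alpha has a rank-independent default (8 in the PEFT releases current at the time), so omitting it means the alpha that is serialized with the adapter, and later parsed back at load time, does not track args.rank; passing lora_alpha=args.rank pins the effective scaling at 1.0 for any rank. The sketch below is a minimal illustration under those assumptions, not code from the patch: it only prints the scaling implied by the two ways of building the config, and the target_modules value is a placeholder.

# Minimal sketch: compare the scaling implied by leaving lora_alpha at its
# default versus matching it to the rank, as the patch above does.
# Assumes the `peft` package is installed; "to_q" is a placeholder target module.
from peft import LoraConfig

for rank in (4, 8, 16, 64):
    implicit = LoraConfig(r=rank, target_modules=["to_q"])  # lora_alpha left at its default
    explicit = LoraConfig(r=rank, lora_alpha=rank, target_modules=["to_q"])  # what the patch does
    print(
        f"rank={rank:>2}  "
        f"default alpha -> scaling {implicit.lora_alpha / implicit.r:.3f}  |  "
        f"alpha=rank -> scaling {explicit.lora_alpha / explicit.r:.3f}"
    )

The test change follows the same idea: both r and lora_alpha in the dummy configs are derived from a single rank variable, so the saved and reloaded adapters are exercised with an alpha that matches the rank.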