From 22f12d7e83b51121dd69e0b9233a8185d1bdbb72 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 24 Sep 2024 13:27:39 +0530
Subject: [PATCH 1/2] clean up the serialization stuff.

---
 tests/lora/utils.py | 108 +++++++++++++++++---------------------------
 1 file changed, 41 insertions(+), 67 deletions(-)

diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 939b749c286a..c4412bb89967 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -201,6 +201,32 @@ def get_dummy_tokens(self):
         prepared_inputs["input_ids"] = inputs
         return prepared_inputs
 
+    def _get_lora_state_dicts(self, modules_to_save):
+        state_dicts = {}
+        for module_name, module in modules_to_save.items():
+            if module is not None:
+                state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
+        return state_dicts
+
+    def _get_modules_to_save(self, pipe, has_denoiser=False):
+        modules_to_save = {}
+        lora_loadable_modules = self.pipeline_class._lora_loadable_modules
+
+        if "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder"):
+            modules_to_save["text_encoder"] = pipe.text_encoder
+
+        if "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2"):
+            modules_to_save["text_encoder_2"] = pipe.text_encoder_2
+
+        if has_denoiser:
+            if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
+                modules_to_save["unet"] = pipe.unet
+
+            if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
+                modules_to_save["transformer"] = pipe.transformer
+
+        return modules_to_save
+
     def test_simple_inference(self):
         """
         Tests a simple inference and makes sure it works as expected
@@ -420,38 +446,17 @@ def test_simple_inference_with_text_lora_save_load(self):
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         with tempfile.TemporaryDirectory() as tmpdirname:
-            text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                    text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
-
-                    self.pipeline_class.save_lora_weights(
-                        save_directory=tmpdirname,
-                        text_encoder_lora_layers=text_encoder_state_dict,
-                        text_encoder_2_lora_layers=text_encoder_2_state_dict,
-                        safe_serialization=False,
-                    )
-            else:
-                self.pipeline_class.save_lora_weights(
-                    save_directory=tmpdirname,
-                    text_encoder_lora_layers=text_encoder_state_dict,
-                    safe_serialization=False,
-                )
+            modules_to_save = self._get_modules_to_save(pipe)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
 
-            if self.has_two_text_encoders:
-                if "text_encoder_2" not in self.pipeline_class._lora_loadable_modules:
-                    self.pipeline_class.save_lora_weights(
-                        save_directory=tmpdirname,
-                        text_encoder_lora_layers=text_encoder_state_dict,
-                        safe_serialization=False,
-                    )
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+            )
 
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             pipe.unload_lora_weights()
-
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
 
-            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
@@ -460,6 +465,8 @@ def test_simple_inference_with_text_lora_save_load(self):
                         check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                     )
 
+            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
             self.assertTrue(
                 np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                 "Loading from saved checkpoints should give same results.",
@@ -614,54 +621,21 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         with tempfile.TemporaryDirectory() as tmpdirname:
-            text_encoder_state_dict = (
-                get_peft_model_state_dict(pipe.text_encoder)
-                if "text_encoder" in self.pipeline_class._lora_loadable_modules
-                else None
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
             )
 
-            denoiser_state_dict = get_peft_model_state_dict(denoiser)
-
-            saving_kwargs = {
-                "save_directory": tmpdirname,
-                "safe_serialization": False,
-            }
-
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                saving_kwargs.update({"text_encoder_lora_layers": text_encoder_state_dict})
-
-            if self.unet_kwargs is not None:
-                saving_kwargs.update({"unet_lora_layers": denoiser_state_dict})
-            else:
-                saving_kwargs.update({"transformer_lora_layers": denoiser_state_dict})
-
-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                    text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
-                    saving_kwargs.update({"text_encoder_2_lora_layers": text_encoder_2_state_dict})
-
-            self.pipeline_class.save_lora_weights(**saving_kwargs)
-
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             pipe.unload_lora_weights()
-
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
 
-            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
-
-            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
-
-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                    self.assertTrue(
-                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                    )
+            # Verify that LoRA layers are correctly set
+            for module_name, module in modules_to_save.items():
+                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
 
+            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
                 np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                 "Loading from saved checkpoints should give same results.",

From 0697afbc20676ec4d1d64dbdb90e4c80a9050a5f Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 24 Sep 2024 13:35:19 +0530
Subject: [PATCH 2/2] better

---
 tests/lora/utils.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index c4412bb89967..f11420d263e2 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -457,13 +457,8 @@ def test_simple_inference_with_text_lora_save_load(self):
             pipe.unload_lora_weights()
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
 
-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                    self.assertTrue(
-                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                    )
+            for module_name, module in modules_to_save.items():
+                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
 
             images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -631,7 +626,6 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
             pipe.unload_lora_weights()
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
 
-            # Verify that LoRA layers are correctly set
             for module_name, module in modules_to_save.items():
                 self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
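
Why the single splat works: _get_lora_state_dicts keys its output as
"{module_name}_lora_layers", which is exactly the keyword-argument spelling
save_lora_weights already accepts (text_encoder_lora_layers,
text_encoder_2_lora_layers, unet_lora_layers, transformer_lora_layers), so one
save_lora_weights(save_directory=..., **lora_state_dicts) call replaces all of
the hand-written branches removed above. A minimal runnable sketch of just that
naming convention, using hypothetical stand-in objects rather than real
pipeline modules and peft state dicts:

    # Hypothetical stand-ins for pipe.text_encoder / pipe.transformer; the
    # real tests hold nn.Module instances here, and None marks a module that
    # is absent from the pipeline.
    modules_to_save = {"text_encoder": object(), "transformer": object(), "unet": None}

    # Mirrors _get_lora_state_dicts: skip absent modules and key each entry by
    # "{module_name}_lora_layers" (placeholder dicts stand in for
    # get_peft_model_state_dict(module)).
    lora_state_dicts = {
        f"{name}_lora_layers": {"placeholder": 0.0}
        for name, module in modules_to_save.items()
        if module is not None
    }

    # The keys line up with save_lora_weights()'s keyword arguments, so a
    # single **-splat covers every text-encoder/denoiser combination.
    assert set(lora_state_dicts) == {"text_encoder_lora_layers", "transformer_lora_layers"}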