From 9d3d68f55d6246d80b6afbdc9e3818b0d07935da Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 28 Sep 2024 10:06:48 +0530 Subject: [PATCH 1/9] handle dora. --- src/diffusers/loaders/lora_pipeline.py | 41 ++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index ba1435a8cbdc..48cb755f24dc 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -99,7 +99,13 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -562,7 +568,14 @@ def load_lora_weights( unet_config=self.unet.config, **kwargs, ) - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -1125,7 +1138,13 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
+ logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -1659,7 +1678,13 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs ) - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -2405,7 +2430,13 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
+ logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") From 1b71c5c5becaf9623beb3351d7fe927a9b5a2ae4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 28 Sep 2024 10:08:33 +0530 Subject: [PATCH 2/9] print test --- tests/lora/test_lora_layers_sdxl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index 4ec7ef897485..3589c9582eda 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -628,6 +628,8 @@ def test_integration_logits_for_dora_lora(self): ).images predicted_slice = images[0, -3:, -3:, -1].flatten() + from diffusers.utils.testing_utils import print_tensor_test + print_tensor_test(predicted_slice) expected_slice_scale = np.array([0.3932, 0.3742, 0.4429, 0.3737, 0.3504, 0.433, 0.3948, 0.3769, 0.4516]) max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice) assert max_diff < 1e-3 From 7a32ee21f41ce7789066154fded3a0057911422e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 28 Sep 2024 10:19:54 +0530 Subject: [PATCH 3/9] debug --- src/diffusers/loaders/lora_pipeline.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 48cb755f24dc..a8a1da3c565c 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -570,10 +570,13 @@ def load_lora_weights( ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) + print(f"{is_dora_scale_present=}") if is_dora_scale_present: warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + print(f"{is_dora_scale_present=}") is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: From 9d7e3a2eae9b70cd553b712e4c0916d0c0b0aca4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 28 Sep 2024 10:36:42 +0530 Subject: [PATCH 4/9] fix --- src/diffusers/loaders/lora_pipeline.py | 39 ++++++++++++-------------- tests/lora/test_lora_layers_sdxl.py | 1 + 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a8a1da3c565c..84ae728f055d 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -99,12 +99,6 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
- logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -217,6 +211,11 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} network_alphas = None # TODO: replace it with a method from `state_dict_utils` @@ -569,15 +568,6 @@ def load_lora_weights( **kwargs, ) - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - print(f"{is_dora_scale_present=}") - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - print(f"{is_dora_scale_present=}") - is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -700,6 +690,14 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + print(f"{is_dora_scale_present=}") + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + print(f"{is_dora_scale_present=}") network_alphas = None # TODO: replace it with a method from `state_dict_utils` @@ -1609,6 +1607,11 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions. 
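[Note: at this point in the series the same warn-and-filter block is duplicated across several `lora_state_dict` methods. A minimal sketch of the shared helper a follow-up could factor out — `_filter_dora_scale_keys` is a hypothetical name, not an existing diffusers function, and the warning text is shortened here:

    import logging

    logger = logging.getLogger(__name__)

    def _filter_dora_scale_keys(state_dict):
        # Same logic as the block repeated in the hunks above: warn once,
        # then drop every `dora_scale` entry before the LoRA format check runs.
        if any("dora_scale" in k for k in state_dict):
            logger.warning(
                "DoRA checkpoints are not supported yet; dropping `dora_scale` keys."
            )
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
        return state_dict

Each `lora_state_dict` would then shrink to a single `state_dict = _filter_dora_scale_keys(state_dict)` call after the checkpoint is loaded.]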
@@ -1681,12 +1684,6 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs ) - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index 3589c9582eda..7e0f0924c33b 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -629,6 +629,7 @@ def test_integration_logits_for_dora_lora(self): predicted_slice = images[0, -3:, -3:, -1].flatten() from diffusers.utils.testing_utils import print_tensor_test + print_tensor_test(predicted_slice) expected_slice_scale = np.array([0.3932, 0.3742, 0.4429, 0.3737, 0.3504, 0.433, 0.3948, 0.3769, 0.4516]) max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice) From 6fcb40b25d45a9f6ee8afe6f7478d3657c8b5612 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 28 Sep 2024 10:37:28 +0530 Subject: [PATCH 5/9] fix-copies --- src/diffusers/loaders/lora_pipeline.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 84ae728f055d..3ff760361737 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -691,13 +691,10 @@ def lora_state_dict( allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) - print(f"{is_dora_scale_present=}") if is_dora_scale_present: warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - print(f"{is_dora_scale_present=}") network_alphas = None # TODO: replace it with a method from `state_dict_utils` From 97d13a5668a76284d4b006e73c09ade817f413fd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 28 Sep 2024 10:40:25 +0530 Subject: [PATCH 6/9] update logits --- tests/lora/test_lora_layers_sdxl.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index 7e0f0924c33b..c87a5e0a85d5 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -628,9 +628,6 @@ def test_integration_logits_for_dora_lora(self): ).images predicted_slice = images[0, -3:, -3:, -1].flatten() - from diffusers.utils.testing_utils import print_tensor_test - - print_tensor_test(predicted_slice) - expected_slice_scale = np.array([0.3932, 0.3742, 0.4429, 0.3737, 0.3504, 0.433, 0.3948, 0.3769, 0.4516]) + expected_slice_scale = np.array([0.1817, 0.0697, 0.2346, 0.0900, 0.1261, 0.2279, 0.1767, 0.1991, 0.2886]) max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice) assert max_diff < 1e-3 From a8fdc76865824cb2b9e6d0e4db4a4a5187d2fb6d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 6 Oct 2024 14:52:48 +0530 Subject: [PATCH 7/9] add warning in the test. --- tests/lora/test_lora_layers_sdxl.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index c87a5e0a85d5..8deecd770c31 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -33,8 +33,10 @@ StableDiffusionXLPipeline, T2IAdapter, ) +from diffusers.utils import logging from diffusers.utils.import_utils import is_accelerate_available from diffusers.utils.testing_utils import ( + CaptureLogger, load_image, nightly, numpy_cosine_similarity_distance, @@ -620,12 +622,16 @@ def test_integration_logits_for_dora_lora(self): pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya") pipeline.enable_model_cpu_offload() - images = pipeline( - "photo of ohwx dog", - num_inference_steps=10, - generator=torch.manual_seed(0), - output_type="np", - ).images + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + images = pipeline( + "photo of ohwx dog", + num_inference_steps=10, + generator=torch.manual_seed(0), + output_type="np", + ).images + assert "It seems like you are using a DoRA checkpoint" in cap_logger.out predicted_slice = images[0, -3:, -3:, -1].flatten() expected_slice_scale = np.array([0.1817, 0.0697, 0.2346, 0.0900, 0.1261, 0.2279, 0.1767, 0.1991, 0.2886]) From e08cf7442b710c3c3fffe36d536e37c2e20ac7da Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 8 Oct 2024 08:33:17 +0530 Subject: [PATCH 8/9] make is_dora check consistent. 
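Move the remaining dora_scale filtering out of the individual load_lora_weights methods and into the corresponding lora_state_dict methods, so every loader warns and filters in the same place and load_lora_weights keeps only the plain key-format check. Roughly, the contract this establishes — a simplified sketch with a toy dict standing in for a real checkpoint:

    # Toy state dict; real ones come from the Hub or a safetensors file.
    checkpoint = {
        "unet.down_blocks.0.lora_A.weight": None,  # placeholder tensor values
        "unet.down_blocks.0.lora_B.weight": None,
        "unet.down_blocks.0.dora_scale": None,     # stripped by lora_state_dict
    }

    # What lora_state_dict now guarantees to every caller:
    filtered = {k: v for k, v in checkpoint.items() if "dora_scale" not in k}

    # ...so the check kept in load_lora_weights passes without a DoRA special case:
    assert all("lora" in key for key in filtered)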
--- src/diffusers/loaders/lora_pipeline.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 3ff760361737..9f22a77b75e0 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1100,6 +1100,12 @@ def lora_state_dict( allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + return state_dict def load_lora_weights( @@ -1136,12 +1142,6 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -1611,7 +1611,6 @@ def lora_state_dict( state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions. - is_kohya = any(".lora_down.weight" in k for k in state_dict) if is_kohya: state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) @@ -2395,6 +2394,11 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} return state_dict @@ -2427,12 +2431,6 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
- logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") From 1c253e2a40f27909b4d39935d339d22c87745568 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 8 Oct 2024 08:40:10 +0530 Subject: [PATCH 9/9] fix-copies --- src/diffusers/loaders/lora_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 9f22a77b75e0..8c8f2dfa84f8 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2394,6 +2394,7 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
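Taken together, the series moves the dora_scale handling into every lora_state_dict so that DoRA checkpoints load with a warning instead of tripping the all("lora" in key ...) format check. A rough end-to-end sketch of the resulting behavior, mirroring the integration test above — the SDXL base checkpoint is an assumption, since the pipeline construction sits outside the hunks shown, and running this needs both models downloaded:

    import torch
    from diffusers import StableDiffusionXLPipeline
    from diffusers.utils import logging
    from diffusers.utils.testing_utils import CaptureLogger

    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )

    logger = logging.get_logger("diffusers.loaders.lora_pipeline")
    logger.setLevel(30)  # logging.WARNING

    # The warning is emitted while the state dict is loaded and filtered, so
    # the capture wraps load_lora_weights here; the in-tree test wraps the
    # inference call instead.
    with CaptureLogger(logger) as cap_logger:
        pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")

    assert "It seems like you are using a DoRA checkpoint" in cap_logger.out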