Skip to content

Commit

Permalink
Fix Flux multiple Lora loading bug (#10388)
Browse files Browse the repository at this point in the history
* check for base_layer key in transformer state dict

* test_lora_expansion_works_for_absent_keys

* check

* Update tests/lora/test_lora_layers_flux.py

Co-authored-by: Sayak Paul <[email protected]>

* check

* test_lora_expansion_works_for_absent_keys/test_lora_expansion_works_for_extra_keys

* absent->extra

---------

Co-authored-by: hlky <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
3 people authored Jan 2, 2025
1 parent 4b9f1c7 commit 44640c8
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2466,7 +2466,9 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
continue

base_param_name = (
f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
f"{k.replace(prefix, '')}.base_layer.weight"
if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
else f"{k.replace(prefix, '')}.weight"
)
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
Expand Down
100 changes: 100 additions & 0 deletions tests/lora/test_lora_layers_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import os
import sys
Expand Down Expand Up @@ -162,6 +163,105 @@ def test_with_alpha_in_state_dict(self):
)
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))

def test_lora_expansion_works_for_absent_keys(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == self.output_shape)

# Modify the config to have a layer which won't be present in the second LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
modified_denoiser_lora_config.target_modules.add("x_embedder")

pipe.transformer.add_adapter(modified_denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
)

with tempfile.TemporaryDirectory() as tmpdirname:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)

self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")

# Modify the state dict to exclude "x_embedder" related LoRA params.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}

pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
pipe.set_adapters(["one", "two"])
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images

self.assertFalse(
np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
"Different LoRAs should lead to different results.",
)
self.assertFalse(
np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
)

def test_lora_expansion_works_for_extra_keys(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == self.output_shape)

# Modify the config to have a layer which won't be present in the first LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
modified_denoiser_lora_config.target_modules.add("x_embedder")

pipe.transformer.add_adapter(modified_denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
)

with tempfile.TemporaryDirectory() as tmpdirname:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)

self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
# Modify the state dict to exclude "x_embedder" related LoRA params.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")

# Load state dict with `x_embedder`.
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")

pipe.set_adapters(["one", "two"])
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images

self.assertFalse(
np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
"Different LoRAs should lead to different results.",
)
self.assertFalse(
np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
)

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
Expand Down

0 comments on commit 44640c8

Please sign in to comment.