Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tests] [LoRA] clean up the serialization stuff. #9512

Merged
merged 4 commits into from
Sep 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 41 additions & 73 deletions tests/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,32 @@ def get_dummy_tokens(self):
prepared_inputs["input_ids"] = inputs
return prepared_inputs

def _get_lora_state_dicts(self, modules_to_save):
state_dicts = {}
for module_name, module in modules_to_save.items():
if module is not None:
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
return state_dicts

def _get_modules_to_save(self, pipe, has_denoiser=False):
modules_to_save = {}
lora_loadable_modules = self.pipeline_class._lora_loadable_modules

if "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder"):
modules_to_save["text_encoder"] = pipe.text_encoder

if "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2"):
modules_to_save["text_encoder_2"] = pipe.text_encoder_2

if has_denoiser:
if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
modules_to_save["unet"] = pipe.unet

if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
modules_to_save["transformer"] = pipe.transformer

return modules_to_save

def test_simple_inference(self):
"""
Tests a simple inference and makes sure it works as expected
Expand Down Expand Up @@ -420,45 +446,21 @@ def test_simple_inference_with_text_lora_save_load(self):
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

with tempfile.TemporaryDirectory() as tmpdirname:
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)

self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
safe_serialization=False,
)
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
safe_serialization=False,
)
modules_to_save = self._get_modules_to_save(pipe)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)

if self.has_two_text_encoders:
if "text_encoder_2" not in self.pipeline_class._lora_loadable_modules:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
safe_serialization=False,
)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)

self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()

pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
Expand Down Expand Up @@ -614,54 +616,20 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

with tempfile.TemporaryDirectory() as tmpdirname:
text_encoder_state_dict = (
get_peft_model_state_dict(pipe.text_encoder)
if "text_encoder" in self.pipeline_class._lora_loadable_modules
else None
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)

denoiser_state_dict = get_peft_model_state_dict(denoiser)

saving_kwargs = {
"save_directory": tmpdirname,
"safe_serialization": False,
}

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
saving_kwargs.update({"text_encoder_lora_layers": text_encoder_state_dict})

if self.unet_kwargs is not None:
saving_kwargs.update({"unet_lora_layers": denoiser_state_dict})
else:
saving_kwargs.update({"transformer_lora_layers": denoiser_state_dict})

if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
saving_kwargs.update({"text_encoder_2_lora_layers": text_encoder_2_state_dict})

self.pipeline_class.save_lora_weights(**saving_kwargs)

self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()

pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)

self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
Expand Down
Loading