From 029fb41695a7940c213d914471fb41a1df67aa17 Mon Sep 17 00:00:00 2001
From: Patrick von Platen 
Date: Thu, 17 Aug 2023 07:24:28 +0200
Subject: [PATCH] [Safetensors] Make safetensors the default way of saving
 weights (#4235)

* make safetensors default

* set default save method as safetensors

* update tests

* update to support saving safetensors

* update test to account for safetensors default

* update example tests to use safetensors

* update example to support safetensors

* update unet tests for safetensors

* fix failing loader tests

* fix qc issues

* fix pipeline tests

* fix example test

---------

Co-authored-by: Dhruv Nair 
---
 .../train_custom_diffusion.py                 | 36 ++++++++++++---
 examples/dreambooth/train_dreambooth_lora.py  |  2 +-
 examples/test_examples.py                     | 46 ++++++++++---------
 .../textual_inversion/textual_inversion.py    | 32 +++++++++++--
 src/diffusers/loaders.py                      | 10 ++--
 src/diffusers/models/modeling_utils.py        |  4 +-
 .../pipelines/controlnet/multicontrolnet.py   |  4 +-
 src/diffusers/pipelines/pipeline_utils.py     |  4 +-
 .../pipeline_stable_diffusion_xl.py           |  2 +-
 .../pipeline_stable_diffusion_xl_img2img.py   |  2 +-
 .../pipeline_stable_diffusion_xl_inpaint.py   |  2 +-
 tests/models/test_lora_layers.py              | 43 ++++-------------
 tests/models/test_modeling_common.py          | 10 ++--
 tests/models/test_models_unet_2d_condition.py | 12 ++---
 tests/models/test_models_unet_3d_condition.py |  8 ++--
 tests/pipelines/test_pipelines.py             |  2 +-
 tests/pipelines/test_pipelines_common.py      |  4 +-
 17 files changed, 126 insertions(+), 97 deletions(-)

diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index a5b4b0846f26..9ec38bdd0435 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -26,6 +26,7 @@ from pathlib import Path
 
 import numpy as np
+import safetensors
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -296,14 +297,19 @@ def __getitem__(self, index):
         return example
 
 
-def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir):
+def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir, safe_serialization=True):
     """Saves the new token embeddings from the text encoder."""
     logger.info("Saving embeddings")
     learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
     for x, y in zip(modifier_token_id, args.modifier_token):
         learned_embeds_dict = {}
         learned_embeds_dict[y] = learned_embeds[x]
-        torch.save(learned_embeds_dict, f"{output_dir}/{y}.bin")
+        filename = f"{output_dir}/{y}.safetensors" if safe_serialization else f"{output_dir}/{y}.bin"
+
+        if safe_serialization:
+            safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
+        else:
+            torch.save(learned_embeds_dict, filename)
 
 
 def parse_args(input_args=None):
@@ -605,6 +611,11 @@ def parse_args(input_args=None):
         action="store_true",
         help="Dont apply augmentation during data augmentation when this flag is enabled.",
     )
+    parser.add_argument(
+        "--no_safe_serialization",
+        action="store_true",
+        help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1244,8 +1255,15 @@ def main(args):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = unet.to(torch.float32)
-        unet.save_attn_procs(args.output_dir)
-        save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.output_dir)
+        unet.save_attn_procs(args.output_dir, 
safe_serialization=not args.no_safe_serialization) + save_new_embed( + text_encoder, + modifier_token_id, + accelerator, + args, + args.output_dir, + safe_serialization=not args.no_safe_serialization, + ) # Final inference # Load previous pipeline @@ -1256,9 +1274,15 @@ def main(args): pipeline = pipeline.to(accelerator.device) # load attention processors - pipeline.unet.load_attn_procs(args.output_dir, weight_name="pytorch_custom_diffusion_weights.bin") + weight_name = ( + "pytorch_custom_diffusion_weights.safetensors" + if not args.no_safe_serialization + else "pytorch_custom_diffusion_weights.bin" + ) + pipeline.unet.load_attn_procs(args.output_dir, weight_name=weight_name) for token in args.modifier_token: - pipeline.load_textual_inversion(args.output_dir, weight_name=f"{token}.bin") + token_weight_name = f"{token}.safetensors" if not args.no_safe_serialization else f"{token}.bin" + pipeline.load_textual_inversion(args.output_dir, weight_name=token_weight_name) # run inference if args.validation_prompt and args.num_validation_images > 0: diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 72d4ab77e0d2..cc1b5df89542 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -1374,7 +1374,7 @@ def compute_text_embeddings(prompt): pipeline = pipeline.to(accelerator.device) # load attention processors - pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin") + pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors") # run inference images = [] diff --git a/examples/test_examples.py b/examples/test_examples.py index 4fd2e485cd0f..c12154a0d572 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -23,7 +23,7 @@ import unittest from typing import List -import torch +import safetensors from accelerate.utils import write_basic_config from diffusers import DiffusionPipeline, UNet2DConditionModel @@ -93,7 +93,7 @@ def test_train_unconditional(self): run_command(self._launch_args + test_args, return_stdout=True) # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) def test_textual_inversion(self): @@ -144,7 +144,7 @@ def test_dreambooth(self): run_command(self._launch_args + test_args) # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) def test_dreambooth_if(self): @@ -170,7 +170,7 @@ def test_dreambooth_if(self): run_command(self._launch_args + test_args) # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) def test_dreambooth_checkpointing(self): @@ -272,10 +272,10 @@ def test_dreambooth_lora(self): run_command(self._launch_args + test_args) # save_pretrained smoke test - 
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) # make sure the state_dict has the correct naming in the parameters. - lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) is_lora = all("lora" in k for k in lora_state_dict.keys()) self.assertTrue(is_lora) @@ -305,10 +305,10 @@ def test_dreambooth_lora_with_text_encoder(self): run_command(self._launch_args + test_args) # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) # check `text_encoder` is present at all. - lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) keys = lora_state_dict.keys() is_text_encoder_present = any(k.startswith("text_encoder") for k in keys) self.assertTrue(is_text_encoder_present) @@ -341,10 +341,10 @@ def test_dreambooth_lora_if_model(self): run_command(self._launch_args + test_args) # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) # make sure the state_dict has the correct naming in the parameters. - lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) is_lora = all("lora" in k for k in lora_state_dict.keys()) self.assertTrue(is_lora) @@ -373,10 +373,10 @@ def test_dreambooth_lora_sdxl(self): run_command(self._launch_args + test_args) # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) # make sure the state_dict has the correct naming in the parameters. - lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) is_lora = all("lora" in k for k in lora_state_dict.keys()) self.assertTrue(is_lora) @@ -406,10 +406,10 @@ def test_dreambooth_lora_sdxl_with_text_encoder(self): run_command(self._launch_args + test_args) # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) # make sure the state_dict has the correct naming in the parameters. 
-        lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+        lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
 
         is_lora = all("lora" in k for k in lora_state_dict.keys())
         self.assertTrue(is_lora)
 
@@ -437,6 +437,7 @@ def test_custom_diffusion(self):
                 --lr_scheduler constant
                 --lr_warmup_steps 0
                 --modifier_token <new1>
+                --no_safe_serialization
                 --output_dir {tmpdir}
                 """.split()
 
@@ -466,7 +467,7 @@ def test_text_to_image(self):
         run_command(self._launch_args + test_args)
 
         # save_pretrained smoke test
-        self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+        self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
         self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
 
     def test_text_to_image_checkpointing(self):
@@ -778,7 +779,7 @@ def test_text_to_image_sdxl(self):
         run_command(self._launch_args + test_args)
 
         # save_pretrained smoke test
-        self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+        self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
         self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
 
     def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
@@ -1373,7 +1374,7 @@ def test_controlnet_sdxl(self):
 
         run_command(self._launch_args + test_args)
 
-        self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.bin")))
+        self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
 
     def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -1390,6 +1391,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
                 --max_train_steps=6
                 --checkpoints_total_limit=2
                 --checkpointing_steps=2
+                --no_safe_serialization
                 """.split()
 
             run_command(self._launch_args + test_args)
@@ -1413,6 +1415,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
                 --dataloader_num_workers=0
                 --max_train_steps=9
                 --checkpointing_steps=2
+                --no_safe_serialization
                 """.split()
 
             run_command(self._launch_args + test_args)
@@ -1436,6 +1439,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
                 --checkpointing_steps=2
                 --resume_from_checkpoint=checkpoint-8
                 --checkpoints_total_limit=3
+                --no_safe_serialization
                 """.split()
 
             run_command(self._launch_args + resume_run_args)
@@ -1464,10 +1468,10 @@ def test_text_to_image_lora_sdxl(self):
         run_command(self._launch_args + test_args)
 
         # save_pretrained smoke test
-        self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+        self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
 
         # make sure the state_dict has the correct naming in the parameters.
- lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) is_lora = all("lora" in k for k in lora_state_dict.keys()) self.assertTrue(is_lora) @@ -1491,10 +1495,10 @@ def test_text_to_image_lora_sdxl_with_text_encoder(self): run_command(self._launch_args + test_args) # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) # make sure the state_dict has the correct naming in the parameters. - lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) is_lora = all("lora" in k for k in lora_state_dict.keys()) self.assertTrue(is_lora) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 78dd578b2d4e..6c11efc82139 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -24,6 +24,7 @@ import numpy as np import PIL +import safetensors import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -157,7 +158,7 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight return images -def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path): +def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True): logger.info("Saving embeddings") learned_embeds = ( accelerator.unwrap_model(text_encoder) @@ -165,7 +166,11 @@ def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_p .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] ) learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} - torch.save(learned_embeds_dict, save_path) + + if safe_serialization: + safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"}) + else: + torch.save(learned_embeds_dict, save_path) def parse_args(): @@ -409,6 +414,11 @@ def parse_args(): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
) + parser.add_argument( + "--no_safe_serialization", + action="store_true", + help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.", + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -878,7 +888,14 @@ def main(): global_step += 1 if global_step % args.save_steps == 0: save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin") - save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path) + save_progress( + text_encoder, + placeholder_token_ids, + accelerator, + args, + save_path, + safe_serialization=not args.no_safe_serialization, + ) if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: @@ -936,7 +953,14 @@ def main(): pipeline.save_pretrained(args.output_dir) # Save the newly trained embeddings save_path = os.path.join(args.output_dir, "learned_embeds.bin") - save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path) + save_progress( + text_encoder, + placeholder_token_ids, + accelerator, + args, + save_path, + safe_serialization=not args.no_safe_serialization, + ) if args.push_to_hub: save_model_card( diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e53040dc2db3..81404e4c9968 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -497,7 +497,8 @@ def save_attn_procs( is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, - safe_serialization: bool = False, + safe_serialization: bool = True, + **kwargs, ): r""" Save an attention processor to a directory so that it can be reloaded using the @@ -514,7 +515,8 @@ def save_attn_procs( The function to use to save the state dictionary. Useful during distributed training when you need to replace `torch.save` with another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. - + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ from .models.attention_processor import ( CustomDiffusionAttnProcessor, @@ -1414,7 +1416,7 @@ def save_lora_weights( is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, - safe_serialization: bool = False, + safe_serialization: bool = True, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -1435,6 +1437,8 @@ def save_lora_weights( The function to use to save the state dictionary. Useful during distributed training when you need to replace `torch.save` with another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ # Create a flat dictionary. state_dict = {} diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index b575c9cdb25e..e53fa7e528b7 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -272,7 +272,7 @@ def save_pretrained( save_directory: Union[str, os.PathLike], is_main_process: bool = True, save_function: Callable = None, - safe_serialization: bool = False, + safe_serialization: bool = True, variant: Optional[str] = None, push_to_hub: bool = False, **kwargs, @@ -292,7 +292,7 @@ def save_pretrained( The function to use to save the state dictionary. 
Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
-            safe_serialization (`bool`, *optional*, defaults to `False`):
+            safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
diff --git a/src/diffusers/pipelines/controlnet/multicontrolnet.py b/src/diffusers/pipelines/controlnet/multicontrolnet.py
index 2214611d26e7..7d284f2d26d3 100644
--- a/src/diffusers/pipelines/controlnet/multicontrolnet.py
+++ b/src/diffusers/pipelines/controlnet/multicontrolnet.py
@@ -77,7 +77,7 @@ def save_pretrained(
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
         save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
         variant: Optional[str] = None,
     ):
         """
@@ -95,7 +95,7 @@ def save_pretrained(
                 The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                 need to replace `torch.save` by another method. Can be configured with the environment variable
                 `DIFFUSERS_SAVE_MODE`.
-            safe_serialization (`bool`, *optional*, defaults to `False`):
+            safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
             variant (`str`, *optional*):
                 If specified, weights are saved in the format pytorch_model.<variant>.bin.
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 75cc0eae8cb9..669018d11a17 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -556,7 +556,7 @@ def __setattr__(self, name: str, value: Any):
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
         variant: Optional[str] = None,
         push_to_hub: bool = False,
         **kwargs,
@@ -569,7 +569,7 @@ class implements both a save and loading method. The pipeline is easily reloaded
         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to save a pipeline to. Will be created if it doesn't exist.
-            safe_serialization (`bool`, *optional*, defaults to `False`):
+            safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
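Note on the hunks above: flipping `safe_serialization` to `True` changes the weight filename that `save_pretrained` writes to disk, which is exactly what the updated example tests assert. A minimal sketch of the new behavior (the tiny `UNet2DModel` config is an arbitrary assumption, chosen only to keep the example fast; any `ModelMixin` subclass or pipeline behaves the same way):

```python
import os
import tempfile

from diffusers import UNet2DModel

# Tiny, arbitrary config (not taken from this patch) so the sketch runs quickly.
model = UNet2DModel(
    sample_size=8,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(4, 8),
    norm_num_groups=4,
    down_block_types=("DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D"),
)

with tempfile.TemporaryDirectory() as tmpdir:
    # After this patch, safetensors is the default serialization format.
    model.save_pretrained(tmpdir)
    assert os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors"))

with tempfile.TemporaryDirectory() as tmpdir:
    # The pickle-based format is still available, but now opt-in.
    model.save_pretrained(tmpdir, safe_serialization=False)
    assert os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.bin"))
```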
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index dcc77ead9fd1..bf6c625bb2b6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -904,7 +904,7 @@ def save_lora_weights( is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, - safe_serialization: bool = False, + safe_serialization: bool = True, ): state_dict = {} diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index f468f902f6a3..d07405d45bfc 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -1058,7 +1058,7 @@ def save_lora_weights( is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, - safe_serialization: bool = False, + safe_serialization: bool = True, ): state_dict = {} diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 8e12d9888b57..c480549aebb3 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -1338,7 +1338,7 @@ def save_lora_weights( is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, - safe_serialization: bool = False, + safe_serialization: bool = True, ): state_dict = {} diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index e5461e560648..7f06da81ba38 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -201,7 +201,7 @@ def create_lora_weight_file(self, tmpdirname): unet_lora_layers=lora_components["unet_lora_layers"], text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) def test_lora_save_load(self): pipeline_components, lora_components = self.get_dummy_components() @@ -220,33 +220,6 @@ def test_lora_save_load(self): unet_lora_layers=lora_components["unet_lora_layers"], text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - sd_pipe.load_lora_weights(tmpdirname) - - lora_images = sd_pipe(**pipeline_inputs).images - lora_image_slice = lora_images[0, -3:, -3:, -1] - - # Outputs shouldn't match. 
- self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) - - def test_lora_save_load_safetensors(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - _, _, pipeline_inputs = self.get_dummy_inputs() - - original_images = sd_pipe(**pipeline_inputs).images - orig_image_slice = original_images[0, -3:, -3:, -1] - - with tempfile.TemporaryDirectory() as tmpdirname: - LoraLoaderMixin.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], - safe_serialization=True, - ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) sd_pipe.load_lora_weights(tmpdirname) @@ -256,7 +229,7 @@ def test_lora_save_load_safetensors(self): # Outputs shouldn't match. self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) - def test_lora_save_load_legacy(self): + def test_lora_save_load_no_safe_serialization(self): pipeline_components, lora_components = self.get_dummy_components() unet_lora_attn_procs = lora_components["unet_lora_attn_procs"] sd_pipe = StableDiffusionPipeline(**pipeline_components) @@ -271,7 +244,7 @@ def test_lora_save_load_legacy(self): with tempfile.TemporaryDirectory() as tmpdirname: unet = sd_pipe.unet unet.set_attn_processor(unet_lora_attn_procs) - unet.save_attn_procs(tmpdirname) + unet.save_attn_procs(tmpdirname, safe_serialization=False) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) sd_pipe.load_lora_weights(tmpdirname) @@ -368,7 +341,7 @@ def test_text_encoder_lora_scale(self): unet_lora_layers=lora_components["unet_lora_layers"], text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) sd_pipe.load_lora_weights(tmpdirname) lora_images = sd_pipe(**pipeline_inputs).images @@ -425,7 +398,7 @@ def test_unload_lora_sd(self): unet_lora_layers=lora_components["unet_lora_layers"], text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) sd_pipe.load_lora_weights(tmpdirname) lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images @@ -501,7 +474,7 @@ def test_lora_save_load_with_xformers(self): unet_lora_layers=lora_components["unet_lora_layers"], text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) sd_pipe.load_lora_weights(tmpdirname) lora_images = sd_pipe(**pipeline_inputs).images @@ -629,7 +602,7 @@ def test_lora_save_load(self): text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + 
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) sd_pipe.load_lora_weights(tmpdirname) lora_images = sd_pipe(**pipeline_inputs).images @@ -658,7 +631,7 @@ def test_unload_lora_sdxl(self): text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) sd_pipe.load_lora_weights(tmpdirname) lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index b9d1f924d78c..9260f16caafd 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -52,7 +52,7 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): model = torch.compile(model) with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) + model.save_pretrained(tmpdirname, safe_serialization=False) new_model = model_class.from_pretrained(tmpdirname) new_model.to(torch_device) @@ -205,7 +205,7 @@ def test_from_save_pretrained(self): model.eval() with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) + model.save_pretrained(tmpdirname, safe_serialization=False) new_model = self.model_class.from_pretrained(tmpdirname) if hasattr(new_model, "set_default_attn_processor"): new_model.set_default_attn_processor() @@ -327,7 +327,7 @@ def test_from_save_pretrained_variant(self): model.eval() with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname, variant="fp16") + model.save_pretrained(tmpdirname, variant="fp16", safe_serialization=False) new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16") if hasattr(new_model, "set_default_attn_processor"): new_model.set_default_attn_processor() @@ -372,7 +372,7 @@ def test_from_save_pretrained_dtype(self): continue with tempfile.TemporaryDirectory() as tmpdirname: model.to(dtype) - model.save_pretrained(tmpdirname) + model.save_pretrained(tmpdirname, safe_serialization=False) new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype) assert new_model.dtype == dtype new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype) @@ -429,7 +429,7 @@ def test_model_from_pretrained(self): # test if the model can be loaded from the config # and has all the expected shape with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) + model.save_pretrained(tmpdirname, safe_serialization=False) new_model = self.model_class.from_pretrained(tmpdirname) new_model.to(torch_device) new_model.eval() diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index bd0a89fcefa0..84bd3ecb6b4c 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -579,7 +579,7 @@ def test_lora_save_load(self): sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname) + model.save_attn_procs(tmpdirname, safe_serialization=False) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) torch.manual_seed(0) new_model = 
self.model_class(**init_dict) @@ -643,12 +643,12 @@ def test_lora_save_safetensors_load_torch(self): model.set_attn_processor(lora_attn_procs) # Saving as torch, properly reloads with directly filename with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + model.save_attn_procs(tmpdirname, safe_serialization=True) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) torch.manual_seed(0) new_model = self.model_class(**init_dict) new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin") + new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.safetensors") def test_lora_save_torch_force_load_safetensors_error(self): # enable deterministic behavior for gradient checkpointing @@ -664,7 +664,7 @@ def test_lora_save_torch_force_load_safetensors_error(self): model.set_attn_processor(lora_attn_procs) # Saving as torch, properly reloads with directly filename with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname) + model.save_attn_procs(tmpdirname, safe_serialization=False) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) torch.manual_seed(0) new_model = self.model_class(**init_dict) @@ -775,7 +775,7 @@ def test_custom_diffusion_save_load(self): sample = model(**inputs_dict).sample with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname) + model.save_attn_procs(tmpdirname, safe_serialization=False) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))) torch.manual_seed(0) new_model = self.model_class(**init_dict) diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index 72a33854bdcd..ed42c582e889 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -252,7 +252,7 @@ def test_lora_save_load(self): sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname) + model.save_attn_procs(tmpdirname, safe_serialization=False) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) torch.manual_seed(0) new_model = self.model_class(**init_dict) @@ -316,11 +316,11 @@ def test_lora_save_safetensors_load_torch(self): # Saving as torch, properly reloads with directly filename with tempfile.TemporaryDirectory() as tmpdirname: model.save_attn_procs(tmpdirname) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) torch.manual_seed(0) new_model = self.model_class(**init_dict) new_model.to(torch_device) - new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin") + new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.safetensors") def test_lora_save_torch_force_load_safetensors_error(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -335,7 +335,7 @@ def test_lora_save_torch_force_load_safetensors_error(self): model.set_attn_processor(lora_attn_procs) # Saving as torch, properly reloads with directly filename with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname) + 
model.save_attn_procs(tmpdirname, safe_serialization=False) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) torch.manual_seed(0) new_model = self.model_class(**init_dict) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 6e59859bdc7a..a46007672247 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -884,7 +884,7 @@ def test_custom_model_and_pipeline(self): ) with tempfile.TemporaryDirectory() as tmpdirname: - pipe.save_pretrained(tmpdirname) + pipe.save_pretrained(tmpdirname, safe_serialization=False) pipe_new = CustomPipeline.from_pretrained(tmpdirname) pipe_new.save_pretrained(tmpdirname) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index be7ae1a31500..8d2b97691f52 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -309,7 +309,7 @@ def test_save_load_local(self, expected_max_difference=1e-4): logger.setLevel(diffusers.logging.INFO) with tempfile.TemporaryDirectory() as tmpdir: - pipe.save_pretrained(tmpdir) + pipe.save_pretrained(tmpdir, safe_serialization=False) with CaptureLogger(logger) as cap_logger: pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) @@ -597,7 +597,7 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): output = pipe(**inputs)[0] with tempfile.TemporaryDirectory() as tmpdir: - pipe.save_pretrained(tmpdir) + pipe.save_pretrained(tmpdir, safe_serialization=False) pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) pipe_loaded.to(torch_device) pipe_loaded.set_progress_bar_config(disable=None)
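
Footnote for reviewers: the updated examples and tests all rely on `safetensors.torch.save_file` and `safetensors.torch.load_file` round-tripping a flat `str -> torch.Tensor` dict. A minimal sketch of that pairing (the placeholder token and embedding size are made up, not taken from the patch):

```python
import os
import tempfile

import safetensors.torch
import torch

# Made-up stand-in for a learned textual-inversion embedding dict.
learned_embeds_dict = {"<my-token>": torch.randn(768)}

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "learned_embeds.safetensors")
    # metadata={"format": "pt"} mirrors what save_progress/save_new_embed write above.
    safetensors.torch.save_file(learned_embeds_dict, path, metadata={"format": "pt"})

    # load_file is what the updated tests use in place of torch.load.
    reloaded = safetensors.torch.load_file(path)
    assert torch.equal(reloaded["<my-token>"], learned_embeds_dict["<my-token>"])
```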