Skip to content

Commit

Permalink
[Safetensors] Make safetensors the default way of saving weights (#4235)
Browse files Browse the repository at this point in the history
* make safetensors default

* set default save method as safetensors

* update tests

* update to support saving safetensors

* update test to account for safetensors default

* update example tests to use safetensors

* update example to support safetensors

* update unet tests for safetensors

* fix failing loader tests

* fix qc issues

* fix pipeline tests

* fix example test

---------

Co-authored-by: Dhruv Nair <[email protected]>
  • Loading branch information
patrickvonplaten and DN6 authored Aug 17, 2023
1 parent 852dc76 commit 029fb41
Show file tree
Hide file tree
Showing 17 changed files with 126 additions and 97 deletions.
36 changes: 30 additions & 6 deletions examples/custom_diffusion/train_custom_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pathlib import Path

import numpy as np
import safetensors
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
Expand Down Expand Up @@ -296,14 +297,19 @@ def __getitem__(self, index):
return example


def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir):
def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir, safe_serialization=True):
"""Saves the new token embeddings from the text encoder."""
logger.info("Saving embeddings")
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
for x, y in zip(modifier_token_id, args.modifier_token):
learned_embeds_dict = {}
learned_embeds_dict[y] = learned_embeds[x]
torch.save(learned_embeds_dict, f"{output_dir}/{y}.bin")
filename = f"{output_dir}/{y}.bin"

if safe_serialization:
safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
else:
torch.save(learned_embeds_dict, filename)


def parse_args(input_args=None):
Expand Down Expand Up @@ -605,6 +611,11 @@ def parse_args(input_args=None):
action="store_true",
help="Dont apply augmentation during data augmentation when this flag is enabled.",
)
parser.add_argument(
"--no_safe_serialization",
action="store_true",
help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
)

if input_args is not None:
args = parser.parse_args(input_args)
Expand Down Expand Up @@ -1244,8 +1255,15 @@ def main(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unet.to(torch.float32)
unet.save_attn_procs(args.output_dir)
save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.output_dir)
unet.save_attn_procs(args.output_dir, safe_serialization=not args.no_safe_serialization)
save_new_embed(
text_encoder,
modifier_token_id,
accelerator,
args,
args.output_dir,
safe_serialization=not args.no_safe_serialization,
)

# Final inference
# Load previous pipeline
Expand All @@ -1256,9 +1274,15 @@ def main(args):
pipeline = pipeline.to(accelerator.device)

# load attention processors
pipeline.unet.load_attn_procs(args.output_dir, weight_name="pytorch_custom_diffusion_weights.bin")
weight_name = (
"pytorch_custom_diffusion_weights.safetensors"
if not args.no_safe_serialization
else "pytorch_custom_diffusion_weights.bin"
)
pipeline.unet.load_attn_procs(args.output_dir, weight_name=weight_name)
for token in args.modifier_token:
pipeline.load_textual_inversion(args.output_dir, weight_name=f"{token}.bin")
token_weight_name = f"{token}.safetensors" if not args.no_safe_serialization else f"{token}.bin"
pipeline.load_textual_inversion(args.output_dir, weight_name=token_weight_name)

# run inference
if args.validation_prompt and args.num_validation_images > 0:
Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,7 +1374,7 @@ def compute_text_embeddings(prompt):
pipeline = pipeline.to(accelerator.device)

# load attention processors
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")

# run inference
images = []
Expand Down
46 changes: 25 additions & 21 deletions examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import unittest
from typing import List

import torch
import safetensors
from accelerate.utils import write_basic_config

from diffusers import DiffusionPipeline, UNet2DConditionModel
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_train_unconditional(self):

run_command(self._launch_args + test_args, return_stdout=True)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

def test_textual_inversion(self):
Expand Down Expand Up @@ -144,7 +144,7 @@ def test_dreambooth(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

def test_dreambooth_if(self):
Expand All @@ -170,7 +170,7 @@ def test_dreambooth_if(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

def test_dreambooth_checkpointing(self):
Expand Down Expand Up @@ -272,10 +272,10 @@ def test_dreambooth_lora(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

Expand Down Expand Up @@ -305,10 +305,10 @@ def test_dreambooth_lora_with_text_encoder(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# check `text_encoder` is present at all.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
keys = lora_state_dict.keys()
is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
self.assertTrue(is_text_encoder_present)
Expand Down Expand Up @@ -341,10 +341,10 @@ def test_dreambooth_lora_if_model(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

Expand Down Expand Up @@ -373,10 +373,10 @@ def test_dreambooth_lora_sdxl(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

Expand Down Expand Up @@ -406,10 +406,10 @@ def test_dreambooth_lora_sdxl_with_text_encoder(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

Expand Down Expand Up @@ -437,6 +437,7 @@ def test_custom_diffusion(self):
--lr_scheduler constant
--lr_warmup_steps 0
--modifier_token <new1>
--no_safe_serialization
--output_dir {tmpdir}
""".split()

Expand Down Expand Up @@ -466,7 +467,7 @@ def test_text_to_image(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

def test_text_to_image_checkpointing(self):
Expand Down Expand Up @@ -778,7 +779,7 @@ def test_text_to_image_sdxl(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
Expand Down Expand Up @@ -1373,7 +1374,7 @@ def test_controlnet_sdxl(self):

run_command(self._launch_args + test_args)

self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))

def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
Expand All @@ -1390,6 +1391,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
--no_safe_serialization
""".split()

run_command(self._launch_args + test_args)
Expand All @@ -1413,6 +1415,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
--dataloader_num_workers=0
--max_train_steps=9
--checkpointing_steps=2
--no_safe_serialization
""".split()

run_command(self._launch_args + test_args)
Expand All @@ -1436,6 +1439,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--no_safe_serialization
""".split()

run_command(self._launch_args + resume_run_args)
Expand Down Expand Up @@ -1464,10 +1468,10 @@ def test_text_to_image_lora_sdxl(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

Expand All @@ -1491,10 +1495,10 @@ def test_text_to_image_lora_sdxl_with_text_encoder(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

Expand Down
32 changes: 28 additions & 4 deletions examples/textual_inversion/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import numpy as np
import PIL
import safetensors
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
Expand Down Expand Up @@ -157,15 +158,19 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
return images


def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path):
def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True):
logger.info("Saving embeddings")
learned_embeds = (
accelerator.unwrap_model(text_encoder)
.get_input_embeddings()
.weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
)
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, save_path)

if safe_serialization:
safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"})
else:
torch.save(learned_embeds_dict, save_path)


def parse_args():
Expand Down Expand Up @@ -409,6 +414,11 @@ def parse_args():
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--no_safe_serialization",
action="store_true",
help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
)

args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
Expand Down Expand Up @@ -878,7 +888,14 @@ def main():
global_step += 1
if global_step % args.save_steps == 0:
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
save_progress(
text_encoder,
placeholder_token_ids,
accelerator,
args,
save_path,
safe_serialization=not args.no_safe_serialization,
)

if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
Expand Down Expand Up @@ -936,7 +953,14 @@ def main():
pipeline.save_pretrained(args.output_dir)
# Save the newly trained embeddings
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
save_progress(
text_encoder,
placeholder_token_ids,
accelerator,
args,
save_path,
safe_serialization=not args.no_safe_serialization,
)

if args.push_to_hub:
save_model_card(
Expand Down
10 changes: 7 additions & 3 deletions src/diffusers/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,8 @@ def save_attn_procs(
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = False,
safe_serialization: bool = True,
**kwargs,
):
r"""
Save an attention processor to a directory so that it can be reloaded using the
Expand All @@ -514,7 +515,8 @@ def save_attn_procs(
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
from .models.attention_processor import (
CustomDiffusionAttnProcessor,
Expand Down Expand Up @@ -1414,7 +1416,7 @@ def save_lora_weights(
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = False,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Expand All @@ -1435,6 +1437,8 @@ def save_lora_weights(
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
# Create a flat dictionary.
state_dict = {}
Expand Down
Loading

0 comments on commit 029fb41

Please sign in to comment.