Merge branch 'main' into layerwise-upcasting
sayakpaul authored Aug 19, 2024
2 parents c64fa22 + 940b8e0 commit 51a855c
Showing 12 changed files with 177 additions and 164 deletions.
3 changes: 3 additions & 0 deletions examples/textual_inversion/README.md
@@ -109,6 +109,9 @@ import torch
 model_id = "path-to-your-trained-model"
 pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to("cuda")
 
+repo_id_embeds = "path-to-your-learned-embeds"
+pipe.load_textual_inversion(repo_id_embeds)
+
 prompt = "A <cat-toy> backpack"
 
 image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
47 changes: 21 additions & 26 deletions src/diffusers/pipelines/pipeline_loading_utils.py
@@ -89,49 +89,44 @@
 ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
 
 
-def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
+def is_safetensors_compatible(filenames, passed_components=None) -> bool:
     """
     Checking for safetensors compatibility:
-    - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
-      files to know which safetensors files are needed.
-    - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
+    - The model is safetensors compatible only if there is a safetensors file for each model component present in
+      filenames.
 
     Converting default pytorch serialized filenames to safetensors serialized filenames:
    - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
    - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
      extension is replaced with ".safetensors"
     """
-    pt_filenames = []
-
-    sf_filenames = set()
-
     passed_components = passed_components or []
 
+    # extract all components of the pipeline and their associated files
+    components = {}
     for filename in filenames:
-        _, extension = os.path.splitext(filename)
+        if not len(filename.split("/")) == 2:
+            continue
 
-        if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
+        component, component_filename = filename.split("/")
+        if component in passed_components:
             continue
 
-        if extension == ".bin":
-            pt_filenames.append(os.path.normpath(filename))
-        elif extension == ".safetensors":
-            sf_filenames.add(os.path.normpath(filename))
+        components.setdefault(component, [])
+        components[component].append(component_filename)
 
-    for filename in pt_filenames:
-        # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
-        path, filename = os.path.split(filename)
-        filename, extension = os.path.splitext(filename)
+    # iterate over all files of a component
+    # check if safetensor files exist for that component
+    # if variant is provided check if the variant of the safetensors exists
+    for component, component_filenames in components.items():
+        matches = []
+        for component_filename in component_filenames:
+            filename, extension = os.path.splitext(component_filename)
 
-        if filename.startswith("pytorch_model"):
-            filename = filename.replace("pytorch_model", "model")
-        else:
-            filename = filename
+            match_exists = extension == ".safetensors"
+            matches.append(match_exists)
 
-        expected_sf_filename = os.path.normpath(os.path.join(path, filename))
-        expected_sf_filename = f"{expected_sf_filename}.safetensors"
-        if expected_sf_filename not in sf_filenames:
-            logger.warning(f"{expected_sf_filename} not found")
+        if not any(matches):
             return False
 
     return True
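A minimal sketch of the reworked check (hypothetical filename lists; the import path matches the tests further below):

from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible

# Each component now only needs *some* .safetensors file, so sharded
# checkpoints sitting next to a legacy .bin file pass:
is_safetensors_compatible(
    [
        "unet/diffusion_pytorch_model.bin",
        "unet/diffusion_pytorch_model-00001-of-00002.safetensors",
        "unet/diffusion_pytorch_model-00002-of-00002.safetensors",
    ]
)  # True

# A component that only ships .bin weights fails the whole check:
is_safetensors_compatible(
    [
        "unet/diffusion_pytorch_model.safetensors",
        "text_encoder/pytorch_model.bin",
    ]
)  # False

# Components supplied by the caller are skipped entirely:
is_safetensors_compatible(
    ["text_encoder/pytorch_model.bin"],
    passed_components=["text_encoder"],
)  # True
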
8 changes: 2 additions & 6 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -1416,18 +1416,14 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
             if (
                 use_safetensors
                 and not allow_pickle
-                and not is_safetensors_compatible(
-                    model_filenames, variant=variant, passed_components=passed_components
-                )
+                and not is_safetensors_compatible(model_filenames, passed_components=passed_components)
             ):
                 raise EnvironmentError(
                     f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
                 )
             if from_flax:
                 ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
-            elif use_safetensors and is_safetensors_compatible(
-                model_filenames, variant=variant, passed_components=passed_components
-            ):
+            elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components):
                 ignore_patterns = ["*.bin", "*.msgpack"]
 
             use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
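For context, a hedged usage sketch of the branch above (repo id illustrative; `download` returns the snapshot path):

from diffusers import DiffusionPipeline

# With use_safetensors=True, download() raises EnvironmentError unless
# is_safetensors_compatible() approves the repo's file list; when it does,
# the "*.bin" / "*.msgpack" weights are skipped via ignore_patterns.
snapshot_path = DiffusionPipeline.download(
    "runwayml/stable-diffusion-v1-5",  # illustrative repo id
    use_safetensors=True,
)
print(snapshot_path)
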
5 changes: 2 additions & 3 deletions tests/lora/test_lora_layers_sd3.py
@@ -32,7 +32,7 @@
 @require_peft_backend
 class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = StableDiffusion3Pipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler()
+    scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
     uses_flow_matching = True
     transformer_kwargs = {
@@ -80,8 +80,7 @@ def test_sd3_lora(self):
         Related PR: https://github.com/huggingface/diffusers/pull/8584
         """
         components = self.get_dummy_components()
-
-        pipe = self.pipeline_class(**components)
+        pipe = self.pipeline_class(**components[0])
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
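The first hunk matters because the shared LoRA test mixin presumably instantiates the scheduler itself, so the attribute must hold the class, not an instance. A hypothetical mini-version of the pattern:

# Hypothetical mini-version of the pattern the fix targets.
class DummyScheduler:
    def __init__(self, shift=1.0):
        self.shift = shift

class GoodTests:
    scheduler_cls = DummyScheduler    # class: scheduler_cls(**kwargs) works
    scheduler_kwargs = {}

class BadTests:
    scheduler_cls = DummyScheduler()  # instance: calling it raises TypeError

GoodTests.scheduler_cls(**GoodTests.scheduler_kwargs)  # ok
# BadTests.scheduler_cls(**BadTests.scheduler_kwargs)  # TypeError: not callable
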
65 changes: 0 additions & 65 deletions tests/lora/test_lora_layers_sdxl.py
@@ -124,71 +124,6 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def test_sdxl_0_9_lora_one(self):
-        generator = torch.Generator().manual_seed(0)
-
-        pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
-        lora_model_id = "hf-internal-testing/sdxl-0.9-daiton-lora"
-        lora_filename = "daiton-xl-lora-test.safetensors"
-        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
-        pipe.enable_model_cpu_offload()
-
-        images = pipe(
-            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
-        ).images
-
-        images = images[0, -3:, -3:, -1].flatten()
-        expected = np.array([0.3838, 0.3482, 0.3588, 0.3162, 0.319, 0.3369, 0.338, 0.3366, 0.3213])
-
-        max_diff = numpy_cosine_similarity_distance(expected, images)
-        assert max_diff < 1e-3
-        pipe.unload_lora_weights()
-        release_memory(pipe)
-
-    def test_sdxl_0_9_lora_two(self):
-        generator = torch.Generator().manual_seed(0)
-
-        pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
-        lora_model_id = "hf-internal-testing/sdxl-0.9-costumes-lora"
-        lora_filename = "saijo.safetensors"
-        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
-        pipe.enable_model_cpu_offload()
-
-        images = pipe(
-            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
-        ).images
-
-        images = images[0, -3:, -3:, -1].flatten()
-        expected = np.array([0.3137, 0.3269, 0.3355, 0.255, 0.2577, 0.2563, 0.2679, 0.2758, 0.2626])
-
-        max_diff = numpy_cosine_similarity_distance(expected, images)
-        assert max_diff < 1e-3
-
-        pipe.unload_lora_weights()
-        release_memory(pipe)
-
-    def test_sdxl_0_9_lora_three(self):
-        generator = torch.Generator().manual_seed(0)
-
-        pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
-        lora_model_id = "hf-internal-testing/sdxl-0.9-kamepan-lora"
-        lora_filename = "kame_sdxl_v2-000020-16rank.safetensors"
-        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
-        pipe.enable_model_cpu_offload()
-
-        images = pipe(
-            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
-        ).images
-
-        images = images[0, -3:, -3:, -1].flatten()
-        expected = np.array([0.4015, 0.3761, 0.3616, 0.3745, 0.3462, 0.3337, 0.3564, 0.3649, 0.3468])
-
-        max_diff = numpy_cosine_similarity_distance(expected, images)
-        assert max_diff < 5e-3
-
-        pipe.unload_lora_weights()
-        release_memory(pipe)
-
     def test_sdxl_1_0_lora(self):
         generator = torch.Generator("cpu").manual_seed(0)
6 changes: 5 additions & 1 deletion tests/models/transformers/test_models_transformer_aura_flow.py
@@ -26,7 +26,7 @@
 enable_full_determinism()
 
 
-class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
+class AuraFlowTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = AuraFlowTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.
@@ -73,3 +73,7 @@ def prepare_init_args_and_inputs_for_common(self):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
+
+    @unittest.skip("AuraFlowTransformer2DModel uses its own dedicated attention processor. This test does not apply")
+    def test_set_attn_processor_for_determinism(self):
+        pass
4 changes: 4 additions & 0 deletions tests/models/transformers/test_models_transformer_sd3.py
@@ -76,3 +76,7 @@ def prepare_init_args_and_inputs_for_common(self):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
+
+    @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
+    def test_set_attn_processor_for_determinism(self):
+        pass
4 changes: 4 additions & 0 deletions tests/pipelines/aura_flow/test_pipeline_aura_flow.py
@@ -163,3 +163,7 @@ def test_fused_qkv_projections(self):
         assert np.allclose(
             original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
         ), "Original outputs should match when fused QKV projections are disabled."
+
+    @unittest.skip("xformers attention processor does not exist for AuraFlow")
+    def test_xformers_attention_forwardGenerator_pass(self):
+        pass
4 changes: 4 additions & 0 deletions tests/pipelines/lumina/test_lumina_nextdit.py
@@ -119,6 +119,10 @@ def test_lumina_prompt_embeds(self):
         max_diff = np.abs(output_with_prompt - output_with_embeds).max()
         assert max_diff < 1e-4
 
+    @unittest.skip("xformers attention processor does not exist for Lumina")
+    def test_xformers_attention_forwardGenerator_pass(self):
+        pass
+
 
 @slow
 @require_torch_gpu
71 changes: 46 additions & 25 deletions tests/pipelines/test_pipeline_utils.py
@@ -68,25 +68,21 @@ def test_all_is_compatible_variant(self):
             "unet/diffusion_pytorch_model.fp16.bin",
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
+        self.assertTrue(is_safetensors_compatible(filenames))
 
     def test_diffusers_model_is_compatible_variant(self):
         filenames = [
             "unet/diffusion_pytorch_model.fp16.bin",
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
+        self.assertTrue(is_safetensors_compatible(filenames))
 
-    def test_diffusers_model_is_compatible_variant_partial(self):
-        # pass variant but use the non-variant filenames
+    def test_diffusers_model_is_compatible_variant_mixed(self):
         filenames = [
             "unet/diffusion_pytorch_model.bin",
-            "unet/diffusion_pytorch_model.safetensors",
+            "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
+        self.assertTrue(is_safetensors_compatible(filenames))
 
     def test_diffusers_model_is_not_compatible_variant(self):
         filenames = [
@@ -99,25 +95,14 @@ def test_diffusers_model_is_not_compatible_variant(self):
             "unet/diffusion_pytorch_model.fp16.bin",
             # Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
         ]
-        variant = "fp16"
-        self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
+        self.assertFalse(is_safetensors_compatible(filenames))
 
     def test_transformer_model_is_compatible_variant(self):
         filenames = [
             "text_encoder/pytorch_model.fp16.bin",
             "text_encoder/model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
-
-    def test_transformer_model_is_compatible_variant_partial(self):
-        # pass variant but use the non-variant filenames
-        filenames = [
-            "text_encoder/pytorch_model.bin",
-            "text_encoder/model.safetensors",
-        ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
+        self.assertTrue(is_safetensors_compatible(filenames))
 
     def test_transformer_model_is_not_compatible_variant(self):
         filenames = [
@@ -126,9 +111,45 @@ def test_transformer_model_is_not_compatible_variant(self):
             "vae/diffusion_pytorch_model.fp16.bin",
             "vae/diffusion_pytorch_model.fp16.safetensors",
             "text_encoder/pytorch_model.fp16.bin",
-            # 'text_encoder/model.fp16.safetensors',
             "unet/diffusion_pytorch_model.fp16.bin",
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
+        self.assertFalse(is_safetensors_compatible(filenames))
+
+    def test_transformers_is_compatible_sharded(self):
+        filenames = [
+            "text_encoder/pytorch_model.bin",
+            "text_encoder/model-00001-of-00002.safetensors",
+            "text_encoder/model-00002-of-00002.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
+
+    def test_transformers_is_compatible_variant_sharded(self):
+        filenames = [
+            "text_encoder/pytorch_model.bin",
+            "text_encoder/model.fp16-00001-of-00002.safetensors",
+            "text_encoder/model.fp16-00002-of-00002.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
+
+    def test_diffusers_is_compatible_sharded(self):
+        filenames = [
+            "unet/diffusion_pytorch_model.bin",
+            "unet/diffusion_pytorch_model-00001-of-00002.safetensors",
+            "unet/diffusion_pytorch_model-00002-of-00002.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
+
+    def test_diffusers_is_compatible_variant_sharded(self):
+        filenames = [
+            "unet/diffusion_pytorch_model.bin",
+            "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
+            "unet/diffusion_pytorch_model.fp16-00002-of-00002.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
+
+    def test_diffusers_is_compatible_only_variants(self):
+        filenames = [
+            "unet/diffusion_pytorch_model.fp16.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
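
The sharded and variant cases pass because the rewritten check only inspects each file's extension; a quick plain-Python illustration:

import os.path

# os.path.splitext splits at the last dot, so shard and variant suffixes
# never hide the ".safetensors" extension the check looks for:
os.path.splitext("model.fp16-00002-of-00002.safetensors")
# -> ('model.fp16-00002-of-00002', '.safetensors')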