From 982aa5932a6d057eefb4241f5db027059374ce6b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 13 Jan 2025 04:45:22 +0100 Subject: [PATCH 1/4] update --- src/diffusers/loaders/single_file_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index cefba48275cf..26b48e59efda 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -604,10 +604,16 @@ def infer_diffusers_model_type(checkpoint): if any( g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] ): - if checkpoint["img_in.weight"].shape[1] == 384: + + if "model.diffusion_model.img_in.weight" in checkpoint: + key = "model.diffusion_model.img_in.weight" + else: + key = "img_in.weight" + + if checkpoint[key].shape[1] == 384: model_type = "flux-fill" - elif checkpoint["img_in.weight"].shape[1] == 128: + elif checkpoint[key].shape[1] == 128: model_type = "flux-depth" else: model_type = "flux-dev" From 8cb29999fca06bb652105a2e04e09f089440001f Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 13 Jan 2025 04:52:28 +0100 Subject: [PATCH 2/4] update --- src/diffusers/loaders/single_file_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 26b48e59efda..faf7ddff13f8 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -604,7 +604,6 @@ def infer_diffusers_model_type(checkpoint): if any( g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] ): - if "model.diffusion_model.img_in.weight" in checkpoint: key = "model.diffusion_model.img_in.weight" else: From 8c0632aa8ec88fcc61dd8ab02f9deeb677b19b7d Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 13 Jan 2025 09:45:02 +0100 Subject: [PATCH 3/4] update --- ...test_model_flux_transformer_single_file.py | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/single_file/test_model_flux_transformer_single_file.py diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py new file mode 100644 index 000000000000..0ec97db26a9e --- /dev/null +++ b/tests/single_file/test_model_flux_transformer_single_file.py @@ -0,0 +1,72 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch + +from diffusers import ( + FluxTransformer2DModel, +) +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + torch_device, +) + + +enable_full_determinism() + + +@require_torch_accelerator +class FluxTransformer2DModelSingleFileTests(unittest.TestCase): + model_class = FluxTransformer2DModel + ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" + alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"] + + repo_id = "black-forest-labs/FLUX.1-dev" + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_single_file_components(self): + model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer") + model_single_file = self.model_class.from_single_file(self.ckpt_path) + + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] + for param_name, param_value in model_single_file.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + assert ( + model.config[param_name] == param_value + ), f"{param_name} differs between single file loading and pretrained loading" + + def test_checkpoint_loading(self): + for ckpt_path in self.alternate_keys_ckpt_paths: + torch.cuda.empty_cache() + model = self.model_class.from_single_file(ckpt_path) + + del model + gc.collect() + torch.cuda.empty_cache() From 50bda8536a80e2c4dc7becd2a1362fe63c132486 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 13 Jan 2025 09:45:52 +0100 Subject: [PATCH 4/4] update --- src/diffusers/loaders/single_file_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index faf7ddff13f8..4ab1954df1fc 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -611,7 +611,6 @@ def infer_diffusers_model_type(checkpoint): if checkpoint[key].shape[1] == 384: model_type = "flux-fill" - elif checkpoint[key].shape[1] == 128: model_type = "flux-depth" else: