diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index b2b21675054c..6b6de7dae3cd 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -605,10 +605,14 @@ def infer_diffusers_model_type(checkpoint): if any( g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] ): - if checkpoint["img_in.weight"].shape[1] == 384: - model_type = "flux-fill" + if "model.diffusion_model.img_in.weight" in checkpoint: + key = "model.diffusion_model.img_in.weight" + else: + key = "img_in.weight" - elif checkpoint["img_in.weight"].shape[1] == 128: + if checkpoint[key].shape[1] == 384: + model_type = "flux-fill" + elif checkpoint[key].shape[1] == 128: model_type = "flux-depth" else: model_type = "flux-dev" diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py new file mode 100644 index 000000000000..0ec97db26a9e --- /dev/null +++ b/tests/single_file/test_model_flux_transformer_single_file.py @@ -0,0 +1,72 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch + +from diffusers import ( + FluxTransformer2DModel, +) +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + torch_device, +) + + +enable_full_determinism() + + +@require_torch_accelerator +class FluxTransformer2DModelSingleFileTests(unittest.TestCase): + model_class = FluxTransformer2DModel + ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" + alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"] + + repo_id = "black-forest-labs/FLUX.1-dev" + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_single_file_components(self): + model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer") + model_single_file = self.model_class.from_single_file(self.ckpt_path) + + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] + for param_name, param_value in model_single_file.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + assert ( + model.config[param_name] == param_value + ), f"{param_name} differs between single file loading and pretrained loading" + + def test_checkpoint_loading(self): + for ckpt_path in self.alternate_keys_ckpt_paths: + torch.cuda.empty_cache() + model = self.model_class.from_single_file(ckpt_path) + + del model + gc.collect() + torch.cuda.empty_cache()