From 0b065c099a9ebbe75206763ca6ef307820df01cc Mon Sep 17 00:00:00 2001
From: hlky
Date: Thu, 16 Jan 2025 17:42:56 +0000
Subject: [PATCH] Move buffers to device (#10523)

* Move buffers to device

* add test

* named_buffers
---
 src/diffusers/loaders/single_file_model.py  |  2 ++
 src/diffusers/models/model_loading_utils.py | 17 ++++++++++++++++-
 src/diffusers/models/modeling_utils.py      |  3 +++
 tests/quantization/bnb/test_mixed_int8.py   | 36 +++++++++++++++++++++++++++-
 4 files changed, 56 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 69ab8b6bad20..c7d0fcb3046e 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -362,6 +362,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
 
         if is_accelerate_available():
             param_device = torch.device(device) if device else torch.device("cpu")
+            named_buffers = model.named_buffers()
             unexpected_keys = load_model_dict_into_meta(
                 model,
                 diffusers_format_checkpoint,
@@ -369,6 +370,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 device=param_device,
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
+                named_buffers=named_buffers,
             )
 
         else:
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 386c07e8747c..0acf50b82356 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -20,7 +20,7 @@
 from array import array
 from collections import OrderedDict
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Tuple, Union
 
 import safetensors
 import torch
@@ -193,6 +193,7 @@ def load_model_dict_into_meta(
     model_name_or_path: Optional[str] = None,
     hf_quantizer=None,
     keep_in_fp32_modules=None,
+    named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
 ) -> List[str]:
     if device is not None and not isinstance(device, (str, torch.device)):
         raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
@@ -254,6 +255,20 @@ def load_model_dict_into_meta(
             else:
                 set_module_tensor_to_device(model, param_name, device, value=param)
 
+    if named_buffers is None:
+        return unexpected_keys
+
+    for param_name, param in named_buffers:
+        if is_quantized and (
+            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+        ):
+            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+        else:
+            if accepts_dtype:
+                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
+            else:
+                set_module_tensor_to_device(model, param_name, device, value=param)
+
     return unexpected_keys
 
 
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index fcd7775fb608..5600cb1e7d78 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -913,6 +913,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                         " those weights or else make sure your checkpoint file is correct."
                     )
 
+                named_buffers = model.named_buffers()
+
                 unexpected_keys = load_model_dict_into_meta(
                     model,
                     state_dict,
@@ -921,6 +923,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     model_name_or_path=pretrained_model_name_or_path,
                     hf_quantizer=hf_quantizer,
                     keep_in_fp32_modules=keep_in_fp32_modules,
+                    named_buffers=named_buffers,
                 )
 
                 if cls._keys_to_ignore_on_load_unexpected is not None:
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 2661196afc70..d1404a2f8929 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -20,7 +20,14 @@
 import pytest
 from huggingface_hub import hf_hub_download
 
-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+from diffusers import (
+    BitsAndBytesConfig,
+    DiffusionPipeline,
+    FluxTransformer2DModel,
+    SanaTransformer2DModel,
+    SD3Transformer2DModel,
+    logging,
+)
 from diffusers.utils import is_accelerate_version
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -302,6 +309,33 @@ def test_device_and_dtype_assignment(self):
         _ = self.model_fp16.cuda()
 
 
+class Bnb8bitDeviceTests(Base8bitTests):
+    def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+        self.model_8bit = SanaTransformer2DModel.from_pretrained(
+            "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
+            subfolder="transformer",
+            quantization_config=mixed_int8_config,
+        )
+
+    def tearDown(self):
+        del self.model_8bit
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_buffers_device_assignment(self):
+        for buffer_name, buffer in self.model_8bit.named_buffers():
+            self.assertEqual(
+                buffer.device.type,
+                torch.device(torch_device).type,
+                f"Expected device {torch_device} for {buffer_name} got {buffer.device}.",
+            )
+
+
 class BnB8bitTrainingTests(Base8bitTests):
     def setUp(self):
         gc.collect()
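
Note (not part of the patch): a minimal, self-contained sketch of the failure mode this change addresses and of the buffer pass it adds. It uses accelerate's init_empty_weights and set_module_tensor_to_device (the latter is the same helper the patch calls); ToyModel and its position_ids buffer are hypothetical stand-ins, not diffusers code.

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device


class ToyModel(torch.nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        # Buffers (e.g. positional ids) are initialized eagerly and, unlike
        # parameters, are not materialized from the checkpoint state dict.
        self.register_buffer("position_ids", torch.arange(4))


# As in from_pretrained with low_cpu_mem_usage: parameters land on the meta
# device, while buffers keep their initialized values on CPU.
with init_empty_weights():
    model = ToyModel()

device = "cpu"  # "cuda" in the scenario the patch fixes

# Parameter materialization, which load_model_dict_into_meta already performed:
state_dict = {"linear.weight": torch.randn(4, 4), "linear.bias": torch.zeros(4)}
for name, value in state_dict.items():
    set_module_tensor_to_device(model, name, device, value=value)

# The pass this patch adds: give every named buffer the same device placement,
# so buffers no longer stay behind on CPU when parameters move to the GPU.
for name, buf in model.named_buffers():
    set_module_tensor_to_device(model, name, device, value=buf)

print({name: buf.device for name, buf in model.named_buffers()})

This mirrors the new loop at the end of load_model_dict_into_meta: previously, a quantized checkpoint loaded onto CUDA could leave buffers on CPU, which the added Bnb8bitDeviceTests.test_buffers_device_assignment now guards against.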