
Move buffers to device (#10523)
* Move buffers to device

* add test

* named_buffers
hlky authored Jan 16, 2025
1 parent b785ddb commit 0b065c0
Showing 4 changed files with 56 additions and 2 deletions.
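
For context: a PyTorch module carries two kinds of tensors, parameters and buffers (tensors registered via register_buffer, e.g. positional embeddings or batch-norm running statistics). The meta-device loading path previously placed only the parameters found in the checkpoint onto the target device, so buffers absent from the state dict could be left on `meta`. A minimal sketch of the failure mode this commit addresses, using a made-up module rather than anything from the diffusers codebase:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)  # a parameter, serialized in checkpoints
        # A non-persistent buffer is module state that never enters the state dict.
        self.register_buffer("scale", torch.ones(4), persistent=False)

with torch.device("meta"):
    model = Toy()  # every tensor starts out on `meta`

# A loader that fills in state-dict entries only will leave the buffer behind.
print("scale" in model.state_dict())  # False
print(model.scale.device)             # meta
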
2 changes: 2 additions & 0 deletions src/diffusers/loaders/single_file_model.py
@@ -362,13 +362,15 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =

         if is_accelerate_available():
             param_device = torch.device(device) if device else torch.device("cpu")
+            named_buffers = model.named_buffers()
             unexpected_keys = load_model_dict_into_meta(
                 model,
                 diffusers_format_checkpoint,
                 dtype=torch_dtype,
                 device=param_device,
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
+                named_buffers=named_buffers,
             )

         else:
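
Note that Module.named_buffers() returns a one-shot generator of (name, tensor) pairs, which is why the caller creates it just before the call and hands the iterator to load_model_dict_into_meta to consume. A small illustration with a stock module:

import torch.nn as nn

m = nn.BatchNorm1d(3)                 # ships with running-statistics buffers
buffers = m.named_buffers()           # a generator, not a list
print([name for name, _ in buffers])  # ['running_mean', 'running_var', 'num_batches_tracked']
print(list(buffers))                  # [] -- the generator is already exhausted
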
17 changes: 16 additions & 1 deletion src/diffusers/models/model_loading_utils.py
@@ -20,7 +20,7 @@
 from array import array
 from collections import OrderedDict
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Tuple, Union

 import safetensors
 import torch

@@ -193,6 +193,7 @@ def load_model_dict_into_meta(
     model_name_or_path: Optional[str] = None,
     hf_quantizer=None,
     keep_in_fp32_modules=None,
+    named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
 ) -> List[str]:
     if device is not None and not isinstance(device, (str, torch.device)):
         raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")

@@ -254,6 +255,20 @@ def load_model_dict_into_meta(
         else:
             set_module_tensor_to_device(model, param_name, device, value=param)

+    if named_buffers is None:
+        return unexpected_keys
+
+    for param_name, param in named_buffers:
+        if is_quantized and (
+            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+        ):
+            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+        else:
+            if accepts_dtype:
+                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
+            else:
+                set_module_tensor_to_device(model, param_name, device, value=param)
+
     return unexpected_keys
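
The new pass mirrors the existing parameter loop: each buffer is either routed through the quantizer or materialized on the target device with accelerate's set_module_tensor_to_device. A standalone sketch of the non-quantized branch, assuming accelerate is installed (the Toy module is illustrative, not from the commit):

import torch
import torch.nn as nn
from accelerate.utils import set_module_tensor_to_device

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(4))

with torch.device("meta"):
    model = Toy()

# Rough equivalent of the buffer loop above, minus the quantization branch:
for name, buf in model.named_buffers():
    # `value` is materialized on the target device, replacing the meta tensor.
    set_module_tensor_to_device(model, name, "cpu", value=torch.ones(buf.shape))

print(model.scale.device)  # cpu
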
3 changes: 3 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -913,6 +913,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
" those weights or else make sure your checkpoint file is correct."
)

named_buffers = model.named_buffers()

unexpected_keys = load_model_dict_into_meta(
model,
state_dict,
Expand All @@ -921,6 +923,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model_name_or_path=pretrained_model_name_or_path,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
named_buffers=named_buffers,
)

if cls._keys_to_ignore_on_load_unexpected is not None:
Expand Down
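
With both entry points now passing named_buffers, a model loaded through from_pretrained should end up with no buffer stranded on `meta`. A quick check along these lines (the checkpoint id and subfolder are placeholders, not part of the commit):

import torch
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",  # placeholder checkpoint
    subfolder="unet",
    torch_dtype=torch.float16,
)
assert all(buf.device.type != "meta" for _, buf in model.named_buffers())
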
36 changes: 35 additions & 1 deletion tests/quantization/bnb/test_mixed_int8.py
@@ -20,7 +20,14 @@
 import pytest
 from huggingface_hub import hf_hub_download

-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+from diffusers import (
+    BitsAndBytesConfig,
+    DiffusionPipeline,
+    FluxTransformer2DModel,
+    SanaTransformer2DModel,
+    SD3Transformer2DModel,
+    logging,
+)
 from diffusers.utils import is_accelerate_version
 from diffusers.utils.testing_utils import (
     CaptureLogger,

@@ -302,6 +309,33 @@ def test_device_and_dtype_assignment(self):
         _ = self.model_fp16.cuda()


+class Bnb8bitDeviceTests(Base8bitTests):
+    def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+        self.model_8bit = SanaTransformer2DModel.from_pretrained(
+            "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
+            subfolder="transformer",
+            quantization_config=mixed_int8_config,
+        )
+
+    def tearDown(self):
+        del self.model_8bit
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_buffers_device_assignment(self):
+        for buffer_name, buffer in self.model_8bit.named_buffers():
+            self.assertEqual(
+                buffer.device.type,
+                torch.device(torch_device).type,
+                f"Expected device {torch_device} for {buffer_name} got {buffer.device}.",
+            )
+
+
 class BnB8bitTrainingTests(Base8bitTests):
     def setUp(self):
         gc.collect()
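
The new test's assertion can be reproduced outside the test harness; this sketch assumes a CUDA machine with bitsandbytes installed and reuses the checkpoint from the test above:

import torch
from diffusers import BitsAndBytesConfig, SanaTransformer2DModel

model = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
for name, buf in model.named_buffers():
    assert buf.device.type == "cuda", f"{name} is on {buf.device}"
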
