Skip to content

Commit

Permalink
[Tests] add: test to check 8bit bnb quantized models work with lora l…
Browse files Browse the repository at this point in the history
…oading. (#10576)

* add: test to check 8bit bnb quantized models work with lora loading.

* Update tests/quantization/bnb/test_mixed_int8.py

Co-authored-by: Dhruv Nair <[email protected]>

---------

Co-authored-by: Dhruv Nair <[email protected]>
  • Loading branch information
sayakpaul and DN6 authored Jan 15, 2025
1 parent 2432f80 commit bba59fb
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/quantization/bnb/test_mixed_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import numpy as np
import pytest
from huggingface_hub import hf_hub_download

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
from diffusers.utils import is_accelerate_version
Expand All @@ -30,6 +31,7 @@
numpy_cosine_similarity_distance,
require_accelerate,
require_bitsandbytes_version_greater,
require_peft_version_greater,
require_torch,
require_torch_gpu,
require_transformers_version_greater,
Expand Down Expand Up @@ -509,6 +511,29 @@ def test_quality(self):
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)

@require_peft_version_greater("0.14.0")
def test_lora_loading(self):
self.pipeline_8bit.load_lora_weights(
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
self.pipeline_8bit.set_adapters("hyper-sd", adapter_weights=0.125)

output = self.pipeline_8bit(
prompt=self.prompt,
height=256,
width=256,
max_sequence_length=64,
output_type="np",
num_inference_steps=8,
generator=torch.manual_seed(42),
).images
out_slice = output[0, -3:, -3:, -1].flatten()

expected_slice = np.array([0.3916, 0.3916, 0.3887, 0.4243, 0.4155, 0.4233, 0.4570, 0.4531, 0.4248])

max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)


@slow
class BaseBnb8bitSerializationTests(Base8bitTests):
Expand Down

0 comments on commit bba59fb

Please sign in to comment.