Merge branch 'main' into auraflow-lora
sayakpaul authored Jan 6, 2025
2 parents 5700e52 + 1896b1f commit 6da81f8
Showing 4 changed files with 54 additions and 7 deletions.
18 changes: 14 additions & 4 deletions examples/community/rerender_a_video.py
@@ -30,10 +30,17 @@
from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import BaseOutput, deprecate, logging
from diffusers.utils import BaseOutput, deprecate, is_torch_xla_available, logging
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__) # pylint: disable=invalid-name


@@ -775,7 +782,7 @@ def __call__(
self.attn_state.reset()

# 4.1 prepare frames
image = self.image_processor.preprocess(frames[0]).to(dtype=torch.float32)
image = self.image_processor.preprocess(frames[0]).to(dtype=self.dtype)
first_image = image[0] # C, H, W

# 4.2 Prepare controlnet_conditioning_image
@@ -919,8 +926,8 @@ def __call__(
prev_image = frames[idx - 1]
control_image = control_frames[idx]
# 5.1 prepare frames
image = self.image_processor.preprocess(image).to(dtype=torch.float32)
prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32)
image = self.image_processor.preprocess(image).to(dtype=self.dtype)
prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)

warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
self.flow_model, first_image, image[0], first_result, False, self.device
@@ -1100,6 +1107,9 @@ def denoising_loop(latents, mask=None, xtrg=None, noise_rescale=None):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

if XLA_AVAILABLE:
xm.mark_step()

return latents

if mask_start_t <= mask_end_t:
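
The rerender_a_video changes follow the guarded-XLA pattern used elsewhere in diffusers: torch_xla is only imported when available, and xm.mark_step() is called once per denoising step so the lazily recorded XLA graph is compiled and executed instead of growing across the whole loop. The same commit also switches frame preprocessing from a hard-coded torch.float32 to self.dtype, so the preprocessed frames match the pipeline's dtype (for example float16) instead of always being upcast. A minimal sketch of the XLA pattern follows; the toy loop body is a placeholder assumption, not the pipeline's real scheduler/UNet step:

    # Sketch of the guarded XLA step-marking pattern; the toy update below is a
    # placeholder, not the pipeline's actual denoising step.
    from diffusers.utils import is_torch_xla_available

    if is_torch_xla_available():
        import torch_xla.core.xla_model as xm

        XLA_AVAILABLE = True
    else:
        XLA_AVAILABLE = False


    def toy_denoising_loop(latents, timesteps):
        for _t in timesteps:
            latents = latents * 0.99  # stand-in for the real denoising update
            if XLA_AVAILABLE:
                xm.mark_step()  # flush the pending XLA graph once per step
        return latents
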
11 changes: 11 additions & 0 deletions src/diffusers/loaders/unet.py
@@ -343,6 +343,17 @@ def _process_lora(
                else:
                    if is_peft_version("<", "0.9.0"):
                        lora_config_kwargs.pop("use_dora")

            if "lora_bias" in lora_config_kwargs:
                if lora_config_kwargs["lora_bias"]:
                    if is_peft_version("<=", "0.13.2"):
                        raise ValueError(
                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
                        )
                else:
                    if is_peft_version("<=", "0.13.2"):
                        lora_config_kwargs.pop("lora_bias")

            lora_config = LoraConfig(**lora_config_kwargs)

# adapter_name
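
The new block above gates peft's `lora_bias` option on the installed peft version: a checkpoint that actually uses LoRA bias fails loudly on peft <= 0.13.2, while an unused `lora_bias=False` entry is simply dropped so older peft's LoraConfig never receives an unknown argument. A standalone sketch of the same gate follows; the helper name and the use of `packaging`/`importlib.metadata` are assumptions for illustration, not diffusers APIs:

    # Standalone sketch of the version gate; `filter_lora_bias` is a hypothetical
    # helper, not part of diffusers.
    from importlib.metadata import version

    from packaging.version import Version


    def filter_lora_bias(lora_config_kwargs):
        peft_version = Version(version("peft"))
        if "lora_bias" in lora_config_kwargs:
            if lora_config_kwargs["lora_bias"]:
                if peft_version <= Version("0.13.2"):
                    raise ValueError(
                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. "
                        "Please upgrade your installation of `peft`."
                    )
            elif peft_version <= Version("0.13.2"):
                lora_config_kwargs.pop("lora_bias")
        return lora_config_kwargs
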
28 changes: 25 additions & 3 deletions src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -11,16 +11,30 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ...utils import (
    USE_PEFT_BACKEND,
    BaseOutput,
    is_torch_xla_available,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionSafetyChecker


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__) # pylint: disable=invalid-name


@@ -282,7 +296,11 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s


class TextToVideoZeroPipeline(
    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
    DiffusionPipeline,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,
    StableDiffusionLoraLoaderMixin,
    FromSingleFileMixin,
):
r"""
Pipeline for zero-shot text-to-video generation using Stable Diffusion.
@@ -440,6 +458,10 @@ def backward_loop(
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

if XLA_AVAILABLE:
xm.mark_step()

return latents.clone().detach()

# Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs
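
Adding FromSingleFileMixin to TextToVideoZeroPipeline exposes the from_single_file constructor, so the zero-shot text-to-video pipeline can be built from a single-file Stable Diffusion checkpoint (e.g. a .safetensors file) rather than a diffusers folder layout. A usage sketch follows; the checkpoint path is a placeholder assumption, not a file referenced by this commit:

    # Usage sketch enabled by FromSingleFileMixin; the checkpoint path is a
    # placeholder, not a file referenced by this commit.
    import torch

    from diffusers import TextToVideoZeroPipeline

    pipe = TextToVideoZeroPipeline.from_single_file(
        "path/to/sd-v1-5.safetensors",  # any single-file SD 1.x checkpoint
        torch_dtype=torch.float16,
    ).to("cuda")
    frames = pipe(prompt="A panda is surfing").images
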
4 changes: 4 additions & 0 deletions tests/lora/test_lora_layers_sd3.py
@@ -29,9 +29,11 @@
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
    nightly,
    numpy_cosine_similarity_distance,
    require_peft_backend,
    require_torch_gpu,
    slow,
    torch_device,
)

@@ -126,6 +128,8 @@ def test_modify_padding_mode(self):
pass


@slow
@nightly
@require_torch_gpu
@require_peft_backend
class LoraSD3IntegrationTests(unittest.TestCase):
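
In diffusers' testing_utils, `slow` and `nightly` skip a test unless the RUN_SLOW / RUN_NIGHTLY environment variables are set, so the integration class above is now excluded from the default fast CI run and only exercised in slow/nightly jobs. A simplified sketch of how such an environment-gated skip decorator can be written follows; the names are hypothetical stand-ins, not the actual diffusers implementation:

    # Simplified, hypothetical stand-in for env-gated skip decorators like
    # @slow / @nightly; the real diffusers helpers differ in detail.
    import os
    import unittest


    def env_gated(flag_name):
        enabled = os.environ.get(flag_name, "0").lower() in ("1", "true", "yes")
        return unittest.skipUnless(enabled, f"set {flag_name}=1 to run this test")


    run_nightly = env_gated("RUN_NIGHTLY")  # use as @run_nightly on a test class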
