Skip to content

Commit

Permalink
[SDXL] Make watermarker optional under certain circumstances to impro…
Browse files Browse the repository at this point in the history
…ve usability of SDXL 1.0 (#4346)

* improve sdxl

* more fixes

* improve sdxl

* improve sdxl

* improve sdxl

* finish
  • Loading branch information
patrickvonplaten authored and sayakpaul committed Jul 28, 2023
1 parent 9cde56a commit c3e3a1e
Show file tree
Hide file tree
Showing 15 changed files with 202 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,25 @@ You can install the libraries as follows:
pip install transformers
pip install accelerate
pip install safetensors
```

### Watermarker

We recommend adding an invisible watermark to images generated by Stable Diffusion XL; this can help downstream applications identify whether an image is machine-synthesised. To do so, please install
the [invisible-watermark library](https://pypi.org/project/invisible-watermark/) via:

```
pip install "invisible-watermark>=0.2.0"
```

If the `invisible-watermark` library is installed, the watermarker will be used **by default**.

If you have other provisions for generating or deploying images safely, you can disable the watermarker as follows:

```py
pipe = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False)
```

### Text-to-Image

You can use SDXL as follows for *text-to-image*:
Expand Down
1 change: 0 additions & 1 deletion examples/controlnet/requirements_sdxl.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
invisible-watermark>=0.2.0
datasets
wandb
1 change: 0 additions & 1 deletion examples/dreambooth/requirements_sdxl.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
invisible-watermark>=0.2.0
19 changes: 5 additions & 14 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@
StableDiffusionPix2PixZeroPipeline,
StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
TextToVideoSDPipeline,
Expand All @@ -202,20 +207,6 @@
VQDiffusionPipeline,
)

try:
if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
else:
from .pipelines import (
StableDiffusionXLControlNetPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
)

try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
Expand Down
22 changes: 7 additions & 15 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from ..utils import (
OptionalDependencyNotAvailable,
is_flax_available,
is_invisible_watermark_available,
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
Expand Down Expand Up @@ -51,6 +50,7 @@
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionXLControlNetPipeline,
)
from .deepfloyd_if import (
IFImg2ImgPipeline,
Expand Down Expand Up @@ -108,6 +108,12 @@
StableUnCLIPPipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .stable_diffusion_xl import (
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
)
from .t2i_adapter import StableDiffusionAdapterPipeline
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
Expand All @@ -121,20 +127,6 @@
from .vq_diffusion import VQDiffusionPipeline


try:
if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
else:
from .controlnet import StableDiffusionXLControlNetPipeline
from .stable_diffusion_xl import (
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
)

try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
Expand Down
11 changes: 1 addition & 10 deletions src/diffusers/pipelines/controlnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
from ...utils import (
OptionalDependencyNotAvailable,
is_flax_available,
is_invisible_watermark_available,
is_torch_available,
is_transformers_available,
)


try:
if not (is_transformers_available() and is_torch_available() and is_invisible_watermark_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
else:
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline


try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
Expand All @@ -26,6 +16,7 @@
from .pipeline_controlnet import StableDiffusionControlNetPipeline
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline


if is_transformers_available() and is_flax_available():
Expand Down
22 changes: 19 additions & 3 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from diffusers.utils.import_utils import is_invisible_watermark_available

from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
Expand All @@ -42,7 +44,11 @@
)
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker


if is_invisible_watermark_available():
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker

from .multicontrolnet import MultiControlNetModel


Expand Down Expand Up @@ -109,6 +115,7 @@ def __init__(
controlnet: ControlNetModel,
scheduler: KarrasDiffusionSchedulers,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
):
super().__init__()

Expand All @@ -130,7 +137,13 @@ def __init__(
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
self.watermark = StableDiffusionXLWatermarker()
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

if add_watermarker:
self.watermark = StableDiffusionXLWatermarker()
else:
self.watermark = None

self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
Expand Down Expand Up @@ -995,7 +1008,10 @@ def __call__(
image = latents
return StableDiffusionXLPipelineOutput(images=image)

image = self.watermark.apply_watermark(image)
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)

image = self.image_processor.postprocess(image, output_type=output_type)

# Offload last model to CPU
Expand Down
5 changes: 2 additions & 3 deletions src/diffusers/pipelines/stable_diffusion_xl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from ...utils import (
BaseOutput,
OptionalDependencyNotAvailable,
is_invisible_watermark_available,
is_torch_available,
is_transformers_available,
)
Expand All @@ -28,10 +27,10 @@ class StableDiffusionXLPipelineOutput(BaseOutput):


try:
if not (is_transformers_available() and is_torch_available() and is_invisible_watermark_available()):
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
from ...utils.dummy_torch_and_transformers_and_objects import * # noqa F403
else:
from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,17 @@
from ...utils import (
is_accelerate_available,
is_accelerate_version,
is_invisible_watermark_available,
logging,
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionXLPipelineOutput
from .watermark import StableDiffusionXLWatermarker


if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -125,6 +129,7 @@ def __init__(
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
):
super().__init__()

Expand All @@ -142,7 +147,12 @@ def __init__(
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size

self.watermark = StableDiffusionXLWatermarker()
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

if add_watermarker:
self.watermark = StableDiffusionXLWatermarker()
else:
self.watermark = None

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
Expand Down Expand Up @@ -839,7 +849,10 @@ def __call__(
image = latents
return StableDiffusionXLPipelineOutput(images=image)

image = self.watermark.apply_watermark(image)
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)

image = self.image_processor.postprocess(image, output_type=output_type)

# Offload last model to CPU
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,17 @@
from ...utils import (
is_accelerate_available,
is_accelerate_version,
is_invisible_watermark_available,
logging,
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionXLPipelineOutput
from .watermark import StableDiffusionXLWatermarker


if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -131,6 +135,7 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
):
super().__init__()

Expand All @@ -148,7 +153,12 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

self.watermark = StableDiffusionXLWatermarker()
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

if add_watermarker:
self.watermark = StableDiffusionXLWatermarker()
else:
self.watermark = None

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
Expand Down Expand Up @@ -990,7 +1000,10 @@ def denoising_value_valid(dnv):
image = latents
return StableDiffusionXLPipelineOutput(images=image)

image = self.watermark.apply_watermark(image)
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)

image = self.image_processor.postprocess(image, output_type=output_type)

# Offload last model to CPU
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,20 @@
XFormersAttnProcessor,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
from ...utils import (
is_accelerate_available,
is_accelerate_version,
is_invisible_watermark_available,
logging,
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionXLPipelineOutput
from .watermark import StableDiffusionXLWatermarker


if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -265,6 +275,7 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
):
super().__init__()

Expand All @@ -282,7 +293,12 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

self.watermark = StableDiffusionXLWatermarker()
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

if add_watermarker:
self.watermark = StableDiffusionXLWatermarker()
else:
self.watermark = None

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
Expand Down Expand Up @@ -1266,6 +1282,10 @@ def denoising_value_valid(dnv):
else:
return StableDiffusionXLPipelineOutput(images=latents)

# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)

image = self.image_processor.postprocess(image, output_type=output_type)

# Offload last model to CPU
Expand Down
Loading

0 comments on commit c3e3a1e

Please sign in to comment.