Skip to content

Commit

Permalink
(doc/foundationals) add IPAdapter, related docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurent2916 committed Feb 2, 2024
1 parent 0643a40 commit 1cbad09
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/reference/foundationals/latent_diffusion.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@
::: refiners.foundationals.latent_diffusion.solvers

::: refiners.foundationals.latent_diffusion.lora

::: refiners.foundationals.latent_diffusion.image_prompt
49 changes: 48 additions & 1 deletion src/refiners/foundationals/latent_diffusion/image_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,12 @@ def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:


class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
"""Image Prompt adapter for a Stable Diffusion U-Net model.
See [[arXiv:2308.06721] IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models](https://arxiv.org/abs/2308.06721)
for more details.
"""

# Prevent PyTorch module registration
_clip_image_encoder: list[CLIPImageEncoderH]
_grid_image_encoder: list[CLIPImageEncoderH]
Expand All @@ -343,6 +349,16 @@ def __init__(
fine_grained: bool = False,
weights: dict[str, Tensor] | None = None,
) -> None:
"""Initialize the adapter.
Args:
target: The target model to adapt.
clip_image_encoder: The CLIP image encoder to use.
image_proj: The image projection to use.
scale: The scale to use for the image prompt.
fine_grained: Whether to use fine-grained image prompt.
weights: The weights of the IPAdapter.
"""
with self.setup_adapter(target):
super().__init__(target)

Expand Down Expand Up @@ -376,6 +392,7 @@ def __init__(

@property
def clip_image_encoder(self) -> CLIPImageEncoderH:
"""The CLIP image encoder of the adapter."""
return self._clip_image_encoder[0]

@property
Expand All @@ -399,6 +416,7 @@ def eject(self) -> None:

@property
def scale(self) -> float:
"""The scale of the adapter."""
return self.sub_adapters[0].scale

@scale.setter
Expand All @@ -411,6 +429,14 @@ def set_scale(self, scale: float) -> None:
cross_attn.scale = scale

def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
"""Set the CLIP image embedding context.
Note:
This is required by `ImageCrossAttention`.
Args:
image_embedding: The CLIP image embedding to set.
"""
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})

@overload
Expand All @@ -433,6 +459,16 @@ def compute_clip_image_embedding(
weights: list[float] | None = None,
concat_batches: bool = True,
) -> Tensor:
"""Compute the CLIP image embedding.
Args:
image_prompt: The image prompt to use.
weights: The scale to use for the image prompt.
concat_batches: Whether to concatenate the batches.
Returns:
The CLIP image embedding.
"""
if isinstance(image_prompt, Image.Image):
image_prompt = self.preprocess_image(image_prompt)
elif isinstance(image_prompt, list):
Expand Down Expand Up @@ -478,7 +514,18 @@ def preprocess_image(
mean: list[float] | None = None,
std: list[float] | None = None,
) -> Tensor:
# Default mean and std are parameters from https://github.com/openai/CLIP
"""Preprocess the image.
Note:
The default mean and std are parameters from
https://github.com/openai/CLIP
Args:
image: The image to preprocess.
size: The size to resize the image to.
mean: The mean to use for normalization.
std: The standard deviation to use for normalization.
"""
return normalize(
image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype),
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@


class SDXLIPAdapter(IPAdapter[SDXLUNet]):
"""Image Prompt adapter for the Stable Diffusion XL U-Net model."""

def __init__(
self,
target: SDXLUNet,
Expand All @@ -16,6 +18,16 @@ def __init__(
fine_grained: bool = False,
weights: dict[str, Tensor] | None = None,
) -> None:
"""Initialize the adapter.
Args:
target: The SDXLUNet model to adapt.
clip_image_encoder: The CLIP image encoder to use.
image_proj: The image projection to use.
scale: The scale to use for the image prompt.
fine_grained: Whether to use fine-grained image prompt.
weights: The weights of the IPAdapter.
"""
clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)

if image_proj is None:
Expand Down

0 comments on commit 1cbad09

Please sign in to comment.