batch sdxl + sd1 + compute_clip_text_embedding
piercus committed Feb 8, 2024
1 parent 6d599d5 commit 2a03655
Showing 5 changed files with 141 additions and 14 deletions.
@@ -1,7 +1,7 @@
import numpy as np
import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType
from torch import Tensor, cat, device as Device, dtype as DType

from refiners.fluxion.utils import image_to_tensor, interpolate
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
@@ -68,7 +68,7 @@ def __init__(
dtype=dtype,
)

def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
def compute_clip_text_embedding(self, text: str | list[str], negative_text: str | list[str] = "") -> Tensor:
"""Compute the CLIP text embedding associated with the given prompt and negative prompt.
Args:
@@ -78,10 +78,19 @@ def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
"""
conditional_embedding = self.clip_text_encoder(text)
if text == negative_text:
return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0)
negative_embedding = conditional_embedding
else:
if isinstance(text, list) and isinstance(negative_text, list):
assert len(text) == len(
negative_text
), "The length of the text list and negative_text should be the same"

if isinstance(negative_text, str) and isinstance(text, list):
negative_text = [negative_text] * len(text)

negative_embedding = self.clip_text_encoder(negative_text)

negative_embedding = self.clip_text_encoder(negative_text or "")
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
return cat((negative_embedding, conditional_embedding))

def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
"""Set the various context parameters required by the U-Net model.
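For orientation, a minimal usage sketch of the batched SD 1.5 call introduced above (hypothetical prompts; `sd15` is assumed to be an already-initialized StableDiffusion_1, and the 768-dimensional shape is an assumption based on CLIPTextEncoderL):

# hypothetical sketch, not part of the commit
prompts = ["a cute cat", "a cute dog"]
negative = "lowres, bad anatomy"  # a single string is broadcast to every prompt in the list

clip_text_embedding = sd15.compute_clip_text_embedding(text=prompts, negative_text=negative)

# negative and conditional embeddings are concatenated along the batch dimension
# for classifier-free guidance, so the leading dimension should be 2 * len(prompts),
# e.g. [4, 77, 768] here (shape values are assumptions)
assert clip_text_embedding.shape[0] == 2 * len(prompts)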
@@ -65,7 +65,9 @@ def __init__(
dtype=dtype,
)

def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
def compute_clip_text_embedding(
self, text: str | list[str], negative_text: str | list[str] | None = None
) -> tuple[Tensor, Tensor]:
"""Compute the CLIP text embedding associated with the given prompt and negative prompt.
Args:
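By analogy, a hedged sketch of the batched SDXL call (`sdxl` is assumed to be an initialized StableDiffusion_XL; the shapes are assumptions inferred from the tests further down, not stated in this diff):

# hypothetical sketch, not part of the commit
prompts = ["a photo of a pizza", "a giant duck"]
negatives = ["blurry", "low resolution"]

clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
    text=prompts, negative_text=negatives
)
time_ids = sdxl.default_time_ids.repeat(len(prompts), 1)  # mirrors the batch-2 test below

# assumed shapes, with negative and conditional embeddings stacked for classifier-free guidance:
# clip_text_embedding:   [2 * len(prompts), 77, 2048]
# pooled_text_embedding: [2 * len(prompts), 1280]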
@@ -40,21 +40,22 @@ def __init__(
def init_context(self) -> Contexts:
return {"text_encoder_pooling": {"end_of_text_index": []}}

def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 1280"], Float[Tensor, "1 1280"]]:
def __call__(self, text: str) -> tuple[Float[Tensor, "batch 77 1280"], Float[Tensor, "batch 1280"]]:
return super().__call__(text)

@property
def tokenizer(self) -> CLIPTokenizer:
return self.ensure_find(CLIPTokenizer)

def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()
end_of_text_index.append(cast(int, position))
for str_tokens in tokens.split(1): # type: ignore
position = (str_tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item() # type: ignore
end_of_text_index.append(cast(int, position))

def pool(self, x: Float[Tensor, "1 77 1280"]) -> Float[Tensor, "1 1280"]:
def pool(self, x: Float[Tensor, "batch 77 1280"]) -> Float[Tensor, "batch 1280"]:
end_of_text_index = self.use_context(context_name="text_encoder_pooling").get("end_of_text_index", [])
assert len(end_of_text_index) == 1, "End of text index not found."
return x[:, end_of_text_index[0], :]
assert len(end_of_text_index) == x.shape[0], "End of text index not found."
return cat([x[i : i + 1, end_of_text_index[i], :] for i in range(x.shape[0])], dim=0)


class DoubleTextEncoder(fl.Chain):
@@ -75,7 +76,7 @@ def __init__(
tep = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
tep.inject(self.layer("Parallel", fl.Parallel))

def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
def __call__(self, text: str | list[str]) -> tuple[Float[Tensor, "batch 77 2048"], Float[Tensor, "batch 1280"]]:
return super().__call__(text)

def concatenate_embeddings(
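For intuition on the pooling change: each batch item is now pooled at its own end-of-text token position instead of a single shared index. A standalone toy sketch of that per-sample gather (shapes and indices are illustrative only, not the refiners module):

import torch

x = torch.randn(2, 77, 1280)   # toy batch of 2 sequences with hidden size 1280
end_of_text_index = [5, 12]    # per-sample end-of-text token positions

# loop form, mirroring the change in this commit
pooled_loop = torch.cat([x[i : i + 1, end_of_text_index[i], :] for i in range(x.shape[0])], dim=0)

# equivalent vectorized form using advanced indexing
pooled_vec = x[torch.arange(x.shape[0]), torch.tensor(end_of_text_index), :]

assert pooled_loop.shape == (2, 1280)
assert torch.equal(pooled_loop, pooled_vec)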
96 changes: 95 additions & 1 deletion tests/e2e/test_diffusion.py
@@ -671,6 +671,39 @@ def test_diffusion_std_random_init(
ensure_similar_images(predicted_image, expected_image_std_random_init)


@no_grad()
def test_diffusion_batch2(
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
):
sd15 = sd15_std

prompt1 = "a cute cat, detailed high-quality professional image"
negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality"
prompt2 = "a cute dog"
negative_prompt2 = "lowres, bad anatomy, bad hands"

clip_text_embedding = sd15.compute_clip_text_embedding(
text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2]
)

sd15.set_inference_steps(30)

manual_seed(2)
x = torch.randn(2, 4, 64, 64, device=test_device)

for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)

predicted_images = sd15.lda.latents_to_images(x)
assert len(predicted_images) == 2
ensure_similar_images(predicted_images[0], expected_image_std_random_init)


@no_grad()
def test_diffusion_std_random_init_euler(
sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device
@@ -750,7 +783,6 @@ def test_diffusion_std_random_init_float16(
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)

ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)


@@ -1106,6 +1138,68 @@ def test_diffusion_lora(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_diffusion_sdxl_batch2(sdxl_ddim: StableDiffusion_XL) -> None:
sdxl = sdxl_ddim

prompt1 = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
negative_prompt1 = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
prompt2 = "professional portrait photo of a boy"
negative_prompt2 = "black and white"

clip_text_embedding_b2, pooled_text_embedding_b2 = sdxl.compute_clip_text_embedding(
text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2]
)

time_ids = sdxl.default_time_ids
time_ids_b2 = sdxl.default_time_ids.repeat(2, 1)
sdxl.set_inference_steps(40)

manual_seed(seed=2)
x_b2 = torch.randn(2, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
x_1 = x_b2[0:1]
x_2 = x_b2[1:2]

x_b2 = sdxl(
x_b2,
step=sdxl.steps[0],
clip_text_embedding=clip_text_embedding_b2,
pooled_text_embedding=pooled_text_embedding_b2,
time_ids=time_ids_b2,
)
predicted_image_b2 = sdxl.lda.latents_to_images(x_b2)

clip_text_embedding_1, pooled_text_embedding_1 = sdxl.compute_clip_text_embedding(
text=prompt1, negative_text=negative_prompt1
)

x_1 = sdxl(
x_1,
step=sdxl.steps[0],
clip_text_embedding=clip_text_embedding_1,
pooled_text_embedding=pooled_text_embedding_1,
time_ids=time_ids,
)
predicted_image_1 = sdxl.lda.latents_to_image(x_1)

clip_text_embedding_2, pooled_text_embedding_2 = sdxl.compute_clip_text_embedding(
text=prompt2, negative_text=negative_prompt2
)

x_2 = sdxl(
x_2,
step=sdxl.steps[0],
clip_text_embedding=clip_text_embedding_2,
pooled_text_embedding=pooled_text_embedding_2,
time_ids=time_ids,
)

predicted_image_2 = sdxl.lda.latents_to_image(x_2)

ensure_similar_images(predicted_image_b2[0], predicted_image_1, min_psnr=35, min_ssim=0.98)
ensure_similar_images(predicted_image_b2[1], predicted_image_2, min_psnr=35, min_ssim=0.98)


@no_grad()
def test_diffusion_sdxl_lora(
sdxl_ddim: StableDiffusion_XL,
21 changes: 21 additions & 0 deletions tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py
@@ -100,3 +100,24 @@ def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder:

assert torch.allclose(input=negative_double_embedding, other=negative_prompt_embeds, rtol=1e-3, atol=1e-3)
assert torch.allclose(input=negative_pooled_embedding, other=negative_pooled_prompt_embeds, rtol=1e-3, atol=1e-3)


@no_grad()
def test_double_text_encoder_batch2(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None:
manual_seed(seed=0)
prompt1 = "A photo of a pizza."
prompt2 = "A giant duck."

double_embedding_b2, pooled_embedding_b2 = double_text_encoder([prompt1, prompt2])

assert double_embedding_b2.shape == torch.Size([2, 77, 2048])
assert pooled_embedding_b2.shape == torch.Size([2, 1280])

double_embedding_1, pooled_embedding_1 = double_text_encoder(prompt1)
double_embedding_2, pooled_embedding_2 = double_text_encoder(prompt2)

assert torch.allclose(input=double_embedding_1, other=double_embedding_b2[0:1], rtol=1e-3, atol=1e-3)
assert torch.allclose(input=pooled_embedding_1, other=pooled_embedding_b2[0:1], rtol=1e-3, atol=1e-3)

assert torch.allclose(input=double_embedding_2, other=double_embedding_b2[1:2], rtol=1e-3, atol=1e-3)
assert torch.allclose(input=pooled_embedding_2, other=pooled_embedding_b2[1:2], rtol=1e-3, atol=1e-3)
