batch sdxl + sd1 + compute_clip_text_embedding
Co-authored-by: Cédric Deltheil <[email protected]>
2 people authored and rodSiry committed Feb 28, 2024
1 parent d6c57bd commit 2c91aab
Showing 5 changed files with 162 additions and 22 deletions.
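
In short: compute_clip_text_embedding now accepts either a single prompt or a list of prompts (with matching negative prompts) for both SD 1.5 and SDXL, the SDXL TextEncoderWithPooling / DoubleTextEncoder handle batched inputs, and new tests check that a batch-of-2 run matches the corresponding batch-of-1 runs within a 5e-3 tolerance.
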
@@ -68,20 +68,22 @@ def __init__(
             dtype=dtype,
         )
 
-    def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
+    def compute_clip_text_embedding(self, text: str | list[str], negative_text: str | list[str] = "") -> Tensor:
         """Compute the CLIP text embedding associated with the given prompt and negative prompt.
 
         Args:
             text: The prompt to compute the CLIP text embedding of.
             negative_text: The negative prompt to compute the CLIP text embedding of.
                 If not provided, the negative prompt is assumed to be empty (i.e., `""`).
         """
+        text = [text] if isinstance(text, str) else text
+        negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
+        assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"
+
         conditional_embedding = self.clip_text_encoder(text)
-        if text == negative_text:
-            return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0)
+        negative_embedding = self.clip_text_encoder(negative_text)
 
-        negative_embedding = self.clip_text_encoder(negative_text or "")
-        return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
+        return torch.cat((negative_embedding, conditional_embedding))
 
     def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
         """Set the various context parameters required by the U-Net model.
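
For context, a minimal usage sketch of the batched SD 1.5 call introduced above. The import path, device, prompts, and weight loading are illustrative assumptions and are not part of this diff:

# Hedged sketch: batched generation with the new list-of-prompts API (SD 1.5).
# Weight loading is omitted and the device is a placeholder; adapt to your setup.
import torch
from refiners.foundationals.latent_diffusion import StableDiffusion_1

sd15 = StableDiffusion_1(device="cuda")  # assumes weights are loaded separately

with torch.no_grad():
    clip_text_embedding = sd15.compute_clip_text_embedding(
        text=["a cute cat", "a cute dog"],
        negative_text=["lowres, bad anatomy", "lowres, bad anatomy"],
    )  # negative embeddings are stacked before the conditional ones along dim 0

    x = torch.randn(2, 4, 64, 64, device=sd15.device)
    for step in sd15.steps:
        x = sd15(x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5)
    images = [sd15.lda.latents_to_image(x[i : i + 1]) for i in range(2)]
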
@@ -65,22 +65,23 @@ def __init__(
             dtype=dtype,
         )
 
-    def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
+    def compute_clip_text_embedding(
+        self, text: str | list[str], negative_text: str | list[str] = ""
+    ) -> tuple[Tensor, Tensor]:
         """Compute the CLIP text embedding associated with the given prompt and negative prompt.
 
         Args:
             text: The prompt to compute the CLIP text embedding of.
             negative_text: The negative prompt to compute the CLIP text embedding of.
                 If not provided, the negative prompt is assumed to be empty (i.e., `""`).
         """
-        conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
-        if text == negative_text:
-            return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0), torch.cat(
-                tensors=(conditional_pooled_embedding, conditional_pooled_embedding), dim=0
-            )
-
-        # TODO: when negative_text is None, use zero tensor?
-        negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text or "")
+        text = [text] if isinstance(text, str) else text
+        negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
+        assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"
+
+        conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
+        negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text)
 
         return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0), torch.cat(
             tensors=(negative_pooled_embedding, conditional_pooled_embedding), dim=0
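
Both the SD 1.5 and the SDXL variant above stack the negative embeddings before the conditional ones along the batch dimension. As a generic illustration of why that layout is convenient, classifier-free guidance typically recombines the two halves as sketched below; this is an assumption-labeled illustration, not the library's internal implementation:

# Generic illustration (not this library's internals): consuming an embedding batch laid
# out as torch.cat((negative, conditional), dim=0) for classifier-free guidance.
import torch

def apply_cfg(unet_out: torch.Tensor, condition_scale: float) -> torch.Tensor:
    # unet_out has shape (2 * batch, ...): the U-Net evaluated on latents repeated twice,
    # once with the negative embeddings and once with the conditional ones.
    negative, conditional = unet_out.chunk(2, dim=0)
    return negative + condition_scale * (conditional - negative)
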
@@ -1,7 +1,7 @@
 from typing import cast
 
 from jaxtyping import Float
-from torch import Tensor, cat, device as Device, dtype as DType
+from torch import Tensor, cat, device as Device, dtype as DType, split
 
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
@@ -40,21 +40,22 @@ def __init__(
     def init_context(self) -> Contexts:
         return {"text_encoder_pooling": {"end_of_text_index": []}}
 
-    def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 1280"], Float[Tensor, "1 1280"]]:
+    def __call__(self, text: str | list[str]) -> tuple[Float[Tensor, "batch 77 1280"], Float[Tensor, "batch 1280"]]:
         return super().__call__(text)
 
     @property
     def tokenizer(self) -> CLIPTokenizer:
         return self.ensure_find(CLIPTokenizer)
 
     def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
-        position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()
-        end_of_text_index.append(cast(int, position))
+        for str_tokens in split(tokens, 1):
+            position = (str_tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()  # type: ignore
+            end_of_text_index.append(cast(int, position))
 
-    def pool(self, x: Float[Tensor, "1 77 1280"]) -> Float[Tensor, "1 1280"]:
+    def pool(self, x: Float[Tensor, "batch 77 1280"]) -> Float[Tensor, "batch 1280"]:
         end_of_text_index = self.use_context(context_name="text_encoder_pooling").get("end_of_text_index", [])
-        assert len(end_of_text_index) == 1, "End of text index not found."
-        return x[:, end_of_text_index[0], :]
+        assert len(end_of_text_index) == x.shape[0], "End of text index not found."
+        return cat([x[i : i + 1, end_of_text_index[i], :] for i in range(x.shape[0])], dim=0)
 
 
 class DoubleTextEncoder(fl.Chain):
@@ -75,7 +76,7 @@ def __init__(
         tep = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
         tep.inject(self.layer("Parallel", fl.Parallel))
 
-    def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
+    def __call__(self, text: str | list[str]) -> tuple[Float[Tensor, "batch 77 2048"], Float[Tensor, "batch 1280"]]:
         return super().__call__(text)
 
     def concatenate_embeddings(
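
To make the batched pooling change in TextEncoderWithPooling concrete, here is a standalone sketch of selecting each row's end-of-text hidden state; the tensor values and token positions are made up, and only torch is assumed:

# Standalone sketch of the batched pooling above: for each batch row, keep the hidden
# state at that row's end-of-text token position. Shapes and indices are illustrative.
import torch

hidden = torch.randn(2, 77, 1280)   # (batch, sequence, hidden) as produced by the encoder
end_of_text_index = [5, 9]          # per-row EOT positions gathered by set_end_of_text_index
pooled = torch.cat(
    [hidden[i : i + 1, end_of_text_index[i], :] for i in range(hidden.shape[0])],
    dim=0,
)
assert pooled.shape == (2, 1280)
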
tests/e2e/test_diffusion.py (117 changes: 116 additions & 1 deletion)
@@ -757,6 +757,60 @@ def test_diffusion_std_random_init(
     ensure_similar_images(predicted_image, expected_image_std_random_init)
 
 
+@no_grad()
+def test_diffusion_batch2(sd15_std: StableDiffusion_1):
+    sd15 = sd15_std
+
+    prompt1 = "a cute cat, detailed high-quality professional image"
+    negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality"
+    prompt2 = "a cute dog"
+    negative_prompt2 = "lowres, bad anatomy, bad hands"
+
+    clip_text_embedding_b2 = sd15.compute_clip_text_embedding(
+        text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2]
+    )
+
+    step = sd15.steps[0]
+
+    manual_seed(2)
+    rand_b2 = torch.randn(2, 4, 64, 64, device=sd15.device)
+
+    x_b2 = sd15(
+        rand_b2,
+        step=step,
+        clip_text_embedding=clip_text_embedding_b2,
+        condition_scale=7.5,
+    )
+
+    assert x_b2.shape == (2, 4, 64, 64)
+
+    rand_1 = rand_b2[0:1]
+    clip_text_embedding_1 = sd15.compute_clip_text_embedding(text=[prompt1], negative_text=[negative_prompt1])
+    x_1 = sd15(
+        rand_1,
+        step=step,
+        clip_text_embedding=clip_text_embedding_1,
+        condition_scale=7.5,
+    )
+
+    rand_2 = rand_b2[1:2]
+    clip_text_embedding_2 = sd15.compute_clip_text_embedding(text=[prompt2], negative_text=[negative_prompt2])
+    x_2 = sd15(
+        rand_2,
+        step=step,
+        clip_text_embedding=clip_text_embedding_2,
+        condition_scale=7.5,
+    )
+
+    # The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911
+    assert torch.allclose(
+        x_b2[0], x_1[0], atol=5e-3, rtol=0
+    ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[0] - x_1[0]).abs()).item()}"
+    assert torch.allclose(
+        x_b2[1], x_2[0], atol=5e-3, rtol=0
+    ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[1] - x_2[0]).abs()).item()}"
+
+
 @no_grad()
 def test_diffusion_std_random_init_euler(
     sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device
@@ -836,7 +890,6 @@ def test_diffusion_std_random_init_float16(
         condition_scale=7.5,
     )
     predicted_image = sd15.lda.latents_to_image(x)
-
     ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
 
 
@@ -1265,6 +1318,68 @@ def test_diffusion_lora(
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
 
+@no_grad()
+def test_diffusion_sdxl_batch2(sdxl_ddim: StableDiffusion_XL) -> None:
+    sdxl = sdxl_ddim
+
+    prompt1 = "a cute cat, detailed high-quality professional image"
+    negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality"
+    prompt2 = "a cute dog"
+    negative_prompt2 = "lowres, bad anatomy, bad hands"
+
+    clip_text_embedding_b2, pooled_text_embedding_b2 = sdxl.compute_clip_text_embedding(
+        text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2]
+    )
+
+    time_ids = sdxl.default_time_ids
+    time_ids_b2 = sdxl.default_time_ids.repeat(2, 1)
+
+    manual_seed(seed=2)
+    x_b2 = torch.randn(2, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
+    x_1 = x_b2[0:1]
+    x_2 = x_b2[1:2]
+
+    x_b2 = sdxl(
+        x_b2,
+        step=sdxl.steps[0],
+        clip_text_embedding=clip_text_embedding_b2,
+        pooled_text_embedding=pooled_text_embedding_b2,
+        time_ids=time_ids_b2,
+    )
+
+    clip_text_embedding_1, pooled_text_embedding_1 = sdxl.compute_clip_text_embedding(
+        text=prompt1, negative_text=negative_prompt1
+    )
+
+    x_1 = sdxl(
+        x_1,
+        step=sdxl.steps[0],
+        clip_text_embedding=clip_text_embedding_1,
+        pooled_text_embedding=pooled_text_embedding_1,
+        time_ids=time_ids,
+    )
+
+    clip_text_embedding_2, pooled_text_embedding_2 = sdxl.compute_clip_text_embedding(
+        text=prompt2, negative_text=negative_prompt2
+    )
+
+    x_2 = sdxl(
+        x_2,
+        step=sdxl.steps[0],
+        clip_text_embedding=clip_text_embedding_2,
+        pooled_text_embedding=pooled_text_embedding_2,
+        time_ids=time_ids,
+    )
+
+    # The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911
+    assert torch.allclose(
+        x_b2[0], x_1[0], atol=5e-3, rtol=0
+    ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[0] - x_1[0]).abs()).item()}"
+    assert torch.allclose(
+        x_b2[1], x_2[0], atol=5e-3, rtol=0
+    ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[1] - x_2[0]).abs()).item()}"
+
+
 @no_grad()
 def test_diffusion_sdxl_lora(
     sdxl_ddim: StableDiffusion_XL,
tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py (21 changes: 21 additions & 0 deletions)
@@ -100,3 +100,24 @@ def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder:
 
     assert torch.allclose(input=negative_double_embedding, other=negative_prompt_embeds, rtol=1e-3, atol=1e-3)
     assert torch.allclose(input=negative_pooled_embedding, other=negative_pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
+
+
+@no_grad()
+def test_double_text_encoder_batch2(double_text_encoder: DoubleTextEncoder) -> None:
+    manual_seed(seed=0)
+    prompt1 = "A photo of a pizza."
+    prompt2 = "A giant duck."
+
+    double_embedding_b2, pooled_embedding_b2 = double_text_encoder([prompt1, prompt2])
+
+    assert double_embedding_b2.shape == torch.Size([2, 77, 2048])
+    assert pooled_embedding_b2.shape == torch.Size([2, 1280])
+
+    double_embedding_1, pooled_embedding_1 = double_text_encoder(prompt1)
+    double_embedding_2, pooled_embedding_2 = double_text_encoder(prompt2)
+
+    assert torch.allclose(input=double_embedding_1, other=double_embedding_b2[0:1], rtol=1e-3, atol=1e-3)
+    assert torch.allclose(input=pooled_embedding_1, other=pooled_embedding_b2[0:1], rtol=1e-3, atol=1e-3)
+
+    assert torch.allclose(input=double_embedding_2, other=double_embedding_b2[1:2], rtol=1e-3, atol=1e-3)
+    assert torch.allclose(input=pooled_embedding_2, other=pooled_embedding_b2[1:2], rtol=1e-3, atol=1e-3)
