Skip to content

Commit

Permalink
clip text, lda encode batch inputs
Browse files Browse the repository at this point in the history
* text_encoder([str1, str2])
* lda decode_latents/encode_image image_to_latent/latent_to_image
* images_to_tensor, tensor_to_images
---------
Co-authored-by: Cédric Deltheil <[email protected]>
  • Loading branch information
piercus authored Feb 1, 2024
1 parent df45b92 commit 21d3a23
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 48 deletions.
11 changes: 11 additions & 0 deletions src/refiners/fluxion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from safetensors.torch import save_file as _save_file # type: ignore
from torch import (
Tensor,
cat,
device as Device,
dtype as DType,
manual_seed as _manual_seed, # type: ignore
Expand Down Expand Up @@ -113,6 +114,12 @@ def default_sigma(kernel_size: int) -> float:
return tensor


def images_to_tensor(
    images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
) -> Tensor:
    """
    Convert a list of PIL Images to a single batched Tensor.

    Each image is converted with `image_to_tensor` (which returns a tensor with a
    leading dimension of 1) and the results are concatenated with `torch.cat`
    along dimension 0, so the output holds one entry per input image, in order.

    Args:
        images: The images to convert. Must contain at least one image.
            `torch.cat` requires matching shapes outside dimension 0, so the
            images are expected to convert to tensors of identical shape.
        device: Forwarded unchanged to `image_to_tensor`.
        dtype: Forwarded unchanged to `image_to_tensor`.

    Returns:
        The per-image tensors concatenated along dimension 0.

    Raises:
        ValueError: If `images` is empty.
    """
    # torch.cat rejects an empty list with an opaque RuntimeError; fail early with a clear message.
    if not images:
        raise ValueError("Expected at least one image, got an empty list")
    return cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])


def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
"""
Convert a PIL Image to a Tensor.
Expand All @@ -135,6 +142,10 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp
return image_tensor.unsqueeze(0)


def tensor_to_images(tensor: Tensor) -> list[Image.Image]:
    """
    Convert a batched Tensor to a list of PIL Images.

    The tensor is split along dimension 0 into chunks of size 1, and each chunk
    is converted with `tensor_to_image`. The output order follows the batch order.
    """
    pil_images: list[Image.Image] = []
    for chunk in tensor.split(1):  # type: ignore
        pil_images.append(tensor_to_image(chunk))
    return pil_images


def tensor_to_image(tensor: Tensor) -> Image.Image:
"""
Convert a Tensor to a PIL Image.
Expand Down
14 changes: 11 additions & 3 deletions src/refiners/foundationals/clip/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from itertools import islice
from pathlib import Path

from torch import Tensor, tensor
from torch import Tensor, cat, tensor

import refiners.fluxion.layers as fl
from refiners.fluxion import pad
Expand Down Expand Up @@ -51,11 +51,19 @@ def __init__(
self.end_of_text_token_id: int = end_of_text_token_id
self.pad_token_id: int = pad_token_id

def forward(self, text: str) -> Tensor:
def forward(self, text: str | list[str]) -> Tensor:
if isinstance(text, str):
return self.tokenize_str(text)
else:
assert isinstance(text, list), f"Expected type `str` or `list[str]`, got {type(text)}"
return cat([self.tokenize_str(txt) for txt in text])

def tokenize_str(self, text: str) -> Tensor:
    """
    Tokenize a single string into a padded tensor of token ids.

    The text is encoded with `self.encode` (called with
    `max_length=self.sequence_length`), a leading dimension of size 1 is added,
    and the result is padded with `self.pad_token_id` so that dimension 1 has
    length `self.sequence_length`.

    Args:
        text: The string to tokenize.

    Returns:
        The padded token tensor; dimension 0 has size 1.

    Raises:
        AssertionError: If the encoded text has more than `self.sequence_length`
            tokens along dimension 1.
    """
    tokens = self.encode(text=text, max_length=self.sequence_length).unsqueeze(dim=0)

    # The message reports len(text) (the character count) next to the token count to help debugging.
    assert (
        tokens.shape[1] <= self.sequence_length
    ), f"Text is too long ({len(text)}): tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
    # NOTE(review): `pad=(0, n)` presumably follows the torch.nn.functional.pad ordering
    # (no left padding, n values appended on the right of the last dimension) — confirm against `refiners.fluxion.pad`.
    return pad(x=tokens, pad=(0, self.sequence_length - tokens.shape[1]), value=self.pad_token_id)

@lru_cache()
Expand Down
21 changes: 17 additions & 4 deletions src/refiners/foundationals/latent_diffusion/auto_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Sum,
Upsample,
)
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
from refiners.fluxion.utils import images_to_tensor, tensor_to_images


class Resnet(Sum):
Expand Down Expand Up @@ -210,12 +210,25 @@ def decode(self, x: Tensor) -> Tensor:
x = decoder(x / self.encoder_scale)
return x

def encode_image(self, image: Image.Image) -> Tensor:
x = image_to_tensor(image, device=self.device, dtype=self.dtype)
def image_to_latents(self, image: Image.Image) -> Tensor:
    """
    Encode a single PIL Image into latents.

    Thin single-image wrapper: puts `image` in a one-element list and
    delegates to `images_to_latents`.
    """
    single_image_batch = [image]
    return self.images_to_latents(single_image_batch)

def images_to_latents(self, images: list[Image.Image]) -> Tensor:
    """
    Encode a list of PIL Images into latents.

    The images are converted to one tensor with `images_to_tensor` (on this
    module's device and dtype), rescaled, and passed through `self.encode`.
    """
    pixels = images_to_tensor(images, device=self.device, dtype=self.dtype)
    # Rescale with 2 * x - 1 before encoding (maps [0, 1] onto [-1, 1];
    # assumes `images_to_tensor` yields values in [0, 1] — TODO confirm).
    return self.encode(2 * pixels - 1)

# backward-compatibility alias
def decode_latents(self, x: Tensor) -> Image.Image:
    """
    Decode latents into a single PIL Image.

    Kept under its previous name for existing callers; it simply forwards to
    `latents_to_image`.
    """
    image = self.latents_to_image(x)
    return image

def latents_to_image(self, x: Tensor) -> Image.Image:
    """
    Decode latents holding exactly one sample into a single PIL Image.

    Delegates to `latents_to_images` and returns its first (and only) element.

    Raises:
        ValueError: If dimension 0 of `x` is not of size 1.
    """
    if x.shape[0] == 1:
        return self.latents_to_images(x)[0]

    raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")

def latents_to_images(self, x: Tensor) -> list[Image.Image]:
    """
    Decode a batch of latents into a list of PIL Images.

    The latents are passed through `self.decode`, rescaled with `(x + 1) / 2`
    (the inverse of the `2 * x - 1` rescale applied in `images_to_latents`),
    and converted with `tensor_to_images`.

    Args:
        x: The latents to decode.

    Returns:
        The images produced by `tensor_to_images` from the decoded batch.
    """
    x = self.decode(x)
    x = (x + 1) / 2
    # Return the whole batch as a list; the single-image `tensor_to_image` call is superseded here.
    return tensor_to_images(x)
2 changes: 1 addition & 1 deletion src/refiners/foundationals/latent_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def init_latents(
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
if init_image is None:
return noise
encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height)))
return self.solver.add_noise(
x=encoded_image,
noise=noise,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,15 @@ def device(self) -> Device:
def dtype(self) -> DType:
return self.ldm.dtype

# backward-compatibility alias
def decode_latents(self, x: Tensor) -> Image.Image:
    """
    Decode latents into a single PIL Image.

    Kept under its previous name for existing callers; it forwards to this
    object's own `latents_to_image` rather than reaching into
    `self.ldm.lda` directly, so both names share one code path.
    """
    return self.latents_to_image(x=x)

def latents_to_image(self, x: Tensor) -> Image.Image:
    """
    Decode latents into a single PIL Image.

    Forwards `x` to the `latents_to_image` method of `self.ldm.lda`.
    """
    autoencoder = self.ldm.lda
    return autoencoder.latents_to_image(x=x)

def latents_to_images(self, x: Tensor) -> list[Image.Image]:
    """
    Decode a batch of latents into a list of PIL Images.

    Forwards `x` to the `latents_to_images` method of `self.ldm.lda`.
    """
    autoencoder = self.ldm.lda
    return autoencoder.latents_to_images(x=x)

@staticmethod
def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:
Expand Down
4 changes: 2 additions & 2 deletions src/refiners/training_utils/latent_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
max_size=self.config.dataset.resize_image_max_size,
)
processed_image = self.process_image(resized_image)
latents = self.lda.encode_image(image=processed_image).to(device=self.device)
latents = self.lda.image_to_latents(image=processed_image).to(device=self.device)
processed_caption = self.process_caption(caption=caption)
clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
Expand Down Expand Up @@ -202,7 +202,7 @@ def compute_evaluation(self) -> None:
step=step,
clip_text_embedding=clip_text_embedding,
)
canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
canvas_image.paste(sd.lda.latents_to_image(x=x), box=(0, 512 * i))
images[prompt] = canvas_image
self.log(data=images)

Expand Down
Loading

0 comments on commit 21d3a23

Please sign in to comment.