Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

List inputs in lda encode/decode, and clip text_encoder. #213

Merged
merged 1 commit into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/refiners/fluxion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from safetensors.torch import save_file as _save_file # type: ignore
from torch import (
Tensor,
cat,
device as Device,
dtype as DType,
manual_seed as _manual_seed, # type: ignore
Expand Down Expand Up @@ -113,6 +114,12 @@ def default_sigma(kernel_size: int) -> float:
return tensor


def images_to_tensor(
images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
) -> Tensor:
return cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])


def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
"""
Convert a PIL Image to a Tensor.
Expand All @@ -135,6 +142,10 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp
return image_tensor.unsqueeze(0)


def tensor_to_images(tensor: Tensor) -> list[Image.Image]:
return [tensor_to_image(t) for t in tensor.split(1)] # type: ignore


def tensor_to_image(tensor: Tensor) -> Image.Image:
"""
Convert a Tensor to a PIL Image.
Expand Down
14 changes: 11 additions & 3 deletions src/refiners/foundationals/clip/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from itertools import islice
from pathlib import Path

from torch import Tensor, tensor
from torch import Tensor, cat, tensor

import refiners.fluxion.layers as fl
from refiners.fluxion import pad
Expand Down Expand Up @@ -51,11 +51,19 @@ def __init__(
self.end_of_text_token_id: int = end_of_text_token_id
self.pad_token_id: int = pad_token_id

def forward(self, text: str) -> Tensor:
def forward(self, text: str | list[str]) -> Tensor:
if isinstance(text, str):
return self.tokenize_str(text)
else:
assert isinstance(text, list), f"Expected type `str` or `list[str]`, got {type(text)}"
return cat([self.tokenize_str(txt) for txt in text])

def tokenize_str(self, text: str) -> Tensor:
tokens = self.encode(text=text, max_length=self.sequence_length).unsqueeze(dim=0)

assert (
tokens.shape[1] <= self.sequence_length
), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
), f"Text is too long ({len(text)}): tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
return pad(x=tokens, pad=(0, self.sequence_length - tokens.shape[1]), value=self.pad_token_id)

@lru_cache()
Expand Down
21 changes: 17 additions & 4 deletions src/refiners/foundationals/latent_diffusion/auto_encoder.py
deltheil marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Sum,
Upsample,
)
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
from refiners.fluxion.utils import images_to_tensor, tensor_to_images


class Resnet(Sum):
Expand Down Expand Up @@ -210,12 +210,25 @@ def decode(self, x: Tensor) -> Tensor:
x = decoder(x / self.encoder_scale)
return x

def encode_image(self, image: Image.Image) -> Tensor:
x = image_to_tensor(image, device=self.device, dtype=self.dtype)
def image_to_latents(self, image: Image.Image) -> Tensor:
return self.images_to_latents([image])

def images_to_latents(self, images: list[Image.Image]) -> Tensor:
x = images_to_tensor(images, device=self.device, dtype=self.dtype)
x = 2 * x - 1
return self.encode(x)

# backward-compatibility alias
def decode_latents(self, x: Tensor) -> Image.Image:
return self.latents_to_image(x)

def latents_to_image(self, x: Tensor) -> Image.Image:
if x.shape[0] != 1:
raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")

return self.latents_to_images(x)[0]

def latents_to_images(self, x: Tensor) -> list[Image.Image]:
x = self.decode(x)
x = (x + 1) / 2
return tensor_to_image(x)
return tensor_to_images(x)
2 changes: 1 addition & 1 deletion src/refiners/foundationals/latent_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def init_latents(
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
if init_image is None:
return noise
encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height)))
return self.solver.add_noise(
x=encoded_image,
noise=noise,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,15 @@ def device(self) -> Device:
def dtype(self) -> DType:
return self.ldm.dtype

# backward-compatibility alias
def decode_latents(self, x: Tensor) -> Image.Image:
return self.ldm.lda.decode_latents(x=x)
return self.latents_to_image(x=x)

def latents_to_image(self, x: Tensor) -> Image.Image:
return self.ldm.lda.latents_to_image(x=x)

def latents_to_images(self, x: Tensor) -> list[Image.Image]:
return self.ldm.lda.latents_to_images(x=x)

@staticmethod
def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:
Expand Down
4 changes: 2 additions & 2 deletions src/refiners/training_utils/latent_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
max_size=self.config.dataset.resize_image_max_size,
)
processed_image = self.process_image(resized_image)
latents = self.lda.encode_image(image=processed_image).to(device=self.device)
latents = self.lda.image_to_latents(image=processed_image).to(device=self.device)
processed_caption = self.process_caption(caption=caption)
clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
Expand Down Expand Up @@ -202,7 +202,7 @@ def compute_evaluation(self) -> None:
step=step,
clip_text_embedding=clip_text_embedding,
)
canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
canvas_image.paste(sd.lda.latents_to_image(x=x), box=(0, 512 * i))
images[prompt] = canvas_image
self.log(data=images)

Expand Down
Loading