From 192f8b924ba8ebd7b5d2b02422d6b2755e123b1d Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 7 Oct 2024 05:59:38 -0700
Subject: [PATCH] fix a few autocast warnings, add new technique for cfg

---
 README.md                           |  9 +++++
 imagen_pytorch/elucidated_imagen.py | 10 +++--
 imagen_pytorch/imagen_pytorch.py    | 57 +++++++++++++++++++++++++++--
 imagen_pytorch/imagen_video.py      | 33 ++++++++++++++++-
 imagen_pytorch/version.py           |  2 +-
 5 files changed, 102 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 7692d7e..095a142 100644
--- a/README.md
+++ b/README.md
@@ -947,3 +947,12 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
     note = {under review}
 }
 ```
+
+```bibtex
+@inproceedings{Sadat2024EliminatingOA,
+    title = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models},
+    author = {Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273098845}
+}
+```
diff --git a/imagen_pytorch/elucidated_imagen.py b/imagen_pytorch/elucidated_imagen.py
index 99dac1c..73262ea 100644
--- a/imagen_pytorch/elucidated_imagen.py
+++ b/imagen_pytorch/elucidated_imagen.py
@@ -9,7 +9,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from torch.nn.parallel import DistributedDataParallel
 import torchvision.transforms as T
 
@@ -565,6 +565,8 @@ def sample(
         video_frames = None,
         batch_size = 1,
         cond_scale = 1.,
+        cfg_remove_parallel_component = True,
+        cfg_keep_parallel_frac = 0.,
         lowres_sample_noise_level = None,
         start_at_unet_number = 1,
         start_image_or_video = None,
@@ -583,7 +585,7 @@ def sample(
         if exists(texts) and not exists(text_embeds) and not self.unconditional:
             assert all([*map(len, texts)]), 'text cannot be empty'
 
-            with autocast(enabled = False):
+            with autocast('cuda', enabled = False):
                 text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
 
             text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
@@ -724,6 +726,8 @@ def sample(
                 sigma_min = unet_sigma_min,
                 sigma_max = unet_sigma_max,
                 cond_scale = unet_cond_scale,
+                remove_parallel_component = cfg_remove_parallel_component,
+                keep_parallel_frac = cfg_keep_parallel_frac,
                 lowres_cond_img = lowres_cond_img,
                 lowres_noise_times = lowres_noise_times,
                 dynamic_threshold = dynamic_threshold,
@@ -811,7 +815,7 @@ def forward(
             assert all([*map(len, texts)]), 'text cannot be empty'
             assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'
 
-            with autocast(enabled = False):
+            with autocast('cuda', enabled = False):
                 text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
 
             text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))
diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py
index bd2918c..2ac7cef 100644
--- a/imagen_pytorch/imagen_pytorch.py
+++ b/imagen_pytorch/imagen_pytorch.py
@@ -11,7 +11,7 @@
 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel
 from torch import nn, einsum
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from torch.special import expm1
 import torchvision.transforms as T
 
@@ -187,6 +187,15 @@ def safe_get_tuple_index(tup, index, default = None):
         return default
     return tup[index]
 
+def pack_one_with_inverse(x, pattern):
+    packed, packed_shape = pack([x], pattern)
+
+    def inverse(x, inverse_pattern = None):
+        inverse_pattern = default(inverse_pattern, pattern)
+        return unpack(x, packed_shape, inverse_pattern)[0]
+
+    return packed, inverse
+
 # image normalization functions
 # ddpms expect images to be in the range of -1 to 1
 
@@ -206,6 +215,21 @@ def prob_mask_like(shape, prob, device):
     else:
         return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
 
+# for improved cfg, getting parallel and orthogonal components of cfg update
+
+def project(x, y):
+    x, inverse = pack_one_with_inverse(x, 'b *')
+    y, _ = pack_one_with_inverse(y, 'b *')
+
+    dtype = x.dtype
+    x, y = x.double(), y.double()
+    unit = F.normalize(y, dim = -1)
+
+    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
+    orthogonal = x - parallel
+
+    return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
+
 # gaussian diffusion with continuous time helper functions and classes
 # large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
 
@@ -1511,6 +1535,8 @@ def forward_with_cond_scale(
         self,
         *args,
         cond_scale = 1.,
+        remove_parallel_component = True,
+        keep_parallel_frac = 0.,
         **kwargs
     ):
         logits = self.forward(*args, **kwargs)
@@ -1519,7 +1545,14 @@ def forward_with_cond_scale(
             return logits
 
         null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
-        return null_logits + (logits - null_logits) * cond_scale
+
+        update = (logits - null_logits)
+
+        if remove_parallel_component:
+            parallel, orthogonal = project(update, logits)
+            update = orthogonal + parallel * keep_parallel_frac
+
+        return logits + update * (cond_scale - 1)
 
     def forward(
         self,
@@ -2055,6 +2088,8 @@ def p_mean_variance(
         self_cond = None,
         lowres_noise_times = None,
         cond_scale = 1.,
+        cfg_remove_parallel_component = True,
+        cfg_keep_parallel_frac = 0.,
         model_output = None,
         t_next = None,
         pred_objective = 'noise',
@@ -2076,6 +2111,8 @@ def p_mean_variance(
             text_mask = text_mask,
             cond_images = cond_images,
             cond_scale = cond_scale,
+            remove_parallel_component = cfg_remove_parallel_component,
+            keep_parallel_frac = cfg_keep_parallel_frac,
             lowres_cond_img = lowres_cond_img,
             self_cond = self_cond,
             lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times),
@@ -2124,6 +2161,8 @@ def p_sample(
         cond_video_frames = None,
         post_cond_video_frames = None,
         cond_scale = 1.,
+        cfg_remove_parallel_component = True,
+        cfg_keep_parallel_frac = 0.,
         self_cond = None,
         lowres_cond_img = None,
         lowres_noise_times = None,
@@ -2149,6 +2188,8 @@ def p_sample(
             text_mask = text_mask,
             cond_images = cond_images,
             cond_scale = cond_scale,
+            cfg_remove_parallel_component = cfg_remove_parallel_component,
+            cfg_keep_parallel_frac = cfg_keep_parallel_frac,
             lowres_cond_img = lowres_cond_img,
             self_cond = self_cond,
             lowres_noise_times = lowres_noise_times,
@@ -2185,6 +2226,8 @@ def p_sample_loop(
         init_images = None,
         skip_steps = None,
         cond_scale = 1,
+        cfg_remove_parallel_component = False,
+        cfg_keep_parallel_frac = 0.,
         pred_objective = 'noise',
         dynamic_threshold = True,
         use_tqdm = True
@@ -2260,6 +2303,8 @@ def p_sample_loop(
                 text_mask = text_mask,
                 cond_images = cond_images,
                 cond_scale = cond_scale,
+                cfg_remove_parallel_component = cfg_remove_parallel_component,
+                cfg_keep_parallel_frac = cfg_keep_parallel_frac,
                 self_cond = self_cond,
                 lowres_cond_img = lowres_cond_img,
                 lowres_noise_times = lowres_noise_times,
@@ -2308,6 +2353,8 @@ def sample(
         skip_steps = None,
         batch_size = 1,
         cond_scale = 1.,
+        cfg_remove_parallel_component = True,
+        cfg_keep_parallel_frac = 0.,
         lowres_sample_noise_level = None,
         start_at_unet_number = 1,
         start_image_or_video = None,
@@ -2326,7 +2373,7 @@ def sample(
         if exists(texts) and not exists(text_embeds) and not self.unconditional:
             assert all([*map(len, texts)]), 'text cannot be empty'
 
-            with autocast(enabled = False):
+            with autocast('cuda', enabled = False):
                 text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
 
             text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
@@ -2469,6 +2516,8 @@ def sample(
                 init_images = unet_init_images,
                 skip_steps = unet_skip_steps,
                 cond_scale = unet_cond_scale,
+                cfg_remove_parallel_component = cfg_remove_parallel_component,
+                cfg_keep_parallel_frac = cfg_keep_parallel_frac,
                 lowres_cond_img = lowres_cond_img,
                 lowres_noise_times = lowres_noise_times,
                 noise_scheduler = noise_scheduler,
@@ -2695,7 +2744,7 @@ def forward(
             assert all([*map(len, texts)]), 'text cannot be empty'
             assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'
 
-            with autocast(enabled = False):
+            with autocast('cuda', enabled = False):
                 text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
 
             text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))
diff --git a/imagen_pytorch/imagen_video.py b/imagen_pytorch/imagen_video.py
index b8ad573..0497645 100644
--- a/imagen_pytorch/imagen_video.py
+++ b/imagen_pytorch/imagen_video.py
@@ -95,6 +95,15 @@ def pad_tuple_to_length(t, length, fillvalue = None):
         return t
     return (*t, *((fillvalue,) * remain_length))
 
+def pack_one_with_inverse(x, pattern):
+    packed, packed_shape = pack([x], pattern)
+
+    def inverse(x, inverse_pattern = None):
+        inverse_pattern = default(inverse_pattern, pattern)
+        return unpack(x, packed_shape, inverse_pattern)[0]
+
+    return packed, inverse
+
 # helper classes
 
 class Identity(nn.Module):
@@ -131,6 +140,19 @@ def masked_mean(t, *, dim, mask = None):
 
     return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
 
+def project(x, y):
+    x, inverse = pack_one_with_inverse(x, 'b *')
+    y, _ = pack_one_with_inverse(y, 'b *')
+
+    dtype = x.dtype
+    x, y = x.double(), y.double()
+    unit = F.normalize(y, dim = -1)
+
+    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
+    orthogonal = x - parallel
+
+    return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
+
 def resize_video_to(
     video,
     target_image_size,
@@ -1637,6 +1659,8 @@ def forward_with_cond_scale(
         self,
         *args,
         cond_scale = 1.,
+        remove_parallel_component = False,
+        keep_parallel_frac = 0.,
         **kwargs
     ):
         logits = self.forward(*args, **kwargs)
@@ -1645,7 +1669,14 @@ def forward_with_cond_scale(
             return logits
 
         null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
-        return null_logits + (logits - null_logits) * cond_scale
+
+        update = (logits - null_logits)
+
+        if remove_parallel_component:
+            parallel, orthogonal = project(update, logits)
+            update = orthogonal + parallel * keep_parallel_frac
+
+        return logits + update * (cond_scale - 1)
 
     def forward(
         self,
diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py
index afced14..a33997d 100644
--- a/imagen_pytorch/version.py
+++ b/imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '2.0.0'
+__version__ = '2.1.0'
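
The guidance change above implements the technique from the Sadat et al. paper cited in the README hunk: the classifier-free guidance update is split into components parallel and orthogonal to the conditional prediction, and the parallel component, which mostly drives oversaturation at high guidance scales, is dropped or down-weighted via `keep_parallel_frac`. Below is a minimal self-contained sketch of the same update rule; `cond_pred` and `null_pred` are hypothetical stand-ins for the unet's two forward passes (the conditioned call and the `cond_drop_prob = 1.` call), and the simplified `inverse` omits the patch's optional `inverse_pattern` argument.

```python
import torch
import torch.nn.functional as F
from einops import pack, unpack

def pack_one_with_inverse(x, pattern):
    # flatten everything after the batch dim, remembering how to undo it
    packed, packed_shape = pack([x], pattern)

    def inverse(out):
        return unpack(out, packed_shape, pattern)[0]

    return packed, inverse

def project(x, y):
    # decompose x into components parallel and orthogonal to y, per batch element
    x, inverse = pack_one_with_inverse(x, 'b *')
    y, _ = pack_one_with_inverse(y, 'b *')

    dtype = x.dtype
    x, y = x.double(), y.double()
    unit = F.normalize(y, dim = -1)

    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
    orthogonal = x - parallel

    return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)

def guided_pred(
    cond_pred,
    null_pred,
    cond_scale = 5.,
    remove_parallel_component = True,
    keep_parallel_frac = 0.
):
    update = cond_pred - null_pred

    if remove_parallel_component:
        # drop (or down-weight) the part of the cfg update that points along
        # the conditional prediction itself - the source of oversaturation
        parallel, orthogonal = project(update, cond_pred)
        update = orthogonal + parallel * keep_parallel_frac

    # with the parallel component kept in full, this is algebraically the
    # classic null_pred + (cond_pred - null_pred) * cond_scale
    return cond_pred + update * (cond_scale - 1)

cond_pred, null_pred = torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)

vanilla = guided_pred(cond_pred, null_pred, remove_parallel_component = False)
assert torch.allclose(vanilla, null_pred + (cond_pred - null_pred) * 5., atol = 1e-5)

projected = guided_pred(cond_pred, null_pred)  # the oversaturation-resistant variant
```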
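
At the sampling level, the two new kwargs are threaded from `sample` through `p_sample_loop`, `p_sample` and `p_mean_variance` into `forward_with_cond_scale`, so the behavior is controlled at inference time. A hypothetical call, assuming an already-trained `imagen` object (both `Imagen` and `ElucidatedImagen` expose the same two kwargs in their `sample` methods after this patch; prompt and scale values are illustrative only):

```python
# hypothetical usage with an already-trained imagen
images = imagen.sample(
    texts = ['a cloud drifting over a mountain lake'],  # any example prompt
    cond_scale = 8.,                         # guidance scales this high tend to oversaturate
    cfg_remove_parallel_component = True,    # project out the parallel part of the cfg update
    cfg_keep_parallel_frac = 0.              # or keep some fraction of it, e.g. 0.25
)
```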