From a81a47c0cb71e3d3b4db60078ca41a5844b17482 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 7 Dec 2022 10:12:32 -0800
Subject: [PATCH] add position generation at each layer of the transformer for cvivit, maskgit, and critic

---
 README.md                          |  3 +-
 phenaki_pytorch/attention.py       | 45 +++++++++++++++++++++++++++++-
 phenaki_pytorch/cvivit.py          | 45 ++++++++----------------------
 phenaki_pytorch/phenaki_pytorch.py |  8 ++++++
 setup.py                           |  2 +-
 5 files changed, 66 insertions(+), 37 deletions(-)

diff --git a/README.md b/README.md
index 625aad3..b50e630 100644
--- a/README.md
+++ b/README.md
@@ -441,9 +441,9 @@ trainer.train()
 - [x] add depthwise-convs to cvivit for position generating
 - [x] some basic video manipulation code, allow for sampled tensor to be saved as gif
 - [x] basic critic training code
+- [x] add position generating dsconv to maskgit too
 
 - [ ] get some basic critic sampling code, show comparison of with and without critic
-- [ ] add position generating dsconv to maskgit too
 - [ ] add all top of the line research for stabilizing transformers training
 - [ ] bring in concatenative token shift (temporal dimension)
 - [ ] add a DDPM upsampler, either port from imagen-pytorch or just rewrite a simple version here
@@ -452,7 +452,6 @@ trainer.train()
 - [ ] test maskgit + critic alone on oxford flowers dataset
 - [ ] support rectangular sized videos
 - [ ] add flash attention as an option for all transformers and cite @tridao
-- [ ] abstract out text conditioning module into own package, and take care of audiolm-pytorch at the same time
 
 ## Citations
 
diff --git a/phenaki_pytorch/attention.py b/phenaki_pytorch/attention.py
index 74578b4..455508c 100644
--- a/phenaki_pytorch/attention.py
+++ b/phenaki_pytorch/attention.py
@@ -3,6 +3,9 @@
 import torch.nn.functional as F
 from torch import nn, einsum
 
+from beartype import beartype
+from typing import Tuple
+
 from einops import rearrange, repeat
 
 # helpers
@@ -32,6 +35,38 @@ def FeedForward(dim, mult = 4):
         nn.Linear(inner_dim, dim, bias = False)
     )
 
+# PEG - position generating module
+
+class PEG(nn.Module):
+    def __init__(self, dim, causal = False):
+        super().__init__()
+        self.causal = causal
+        self.dsconv = nn.Conv3d(dim, dim, 3, groups = dim)
+
+    @beartype
+    def forward(self, x, shape: Tuple[int, int, int, int] = None):
+        needs_shape = x.ndim == 3
+        assert not (needs_shape and not exists(shape))
+
+        orig_shape = x.shape
+
+        if needs_shape:
+            x = x.reshape(*shape, -1)
+
+        x = rearrange(x, 'b ... d -> b d ...')
+
+        frame_padding = (2, 0) if self.causal else (1, 1)
+
+        x = F.pad(x, (1, 1, 1, 1, *frame_padding), value = 0.)
+        x = self.dsconv(x)
+
+        x = rearrange(x, 'b d ... -> b ... d')
+
+        if needs_shape:
+            x = rearrange(x, 'b ... d -> b (...) d')
+
+        return x.reshape(orig_shape)
+
 # attention
 
 class Attention(nn.Module):
@@ -224,6 +259,8 @@ def __init__(
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
+        peg = False,
+        peg_causal = False,
         attn_num_null_kv = 2,
         has_cross_attn = False
     ):
@@ -232,6 +269,7 @@ def __init__(
 
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
+                PEG(dim = dim, causal = peg_causal) if peg else None,
                 Attention(dim = dim, dim_head = dim_head, heads = heads, causal = causal),
                 Attention(dim = dim, dim_head = dim_head, dim_context = dim_context, heads = heads, causal = False, num_null_kv = attn_num_null_kv) if has_cross_attn else None,
                 FeedForward(dim = dim, mult = ff_mult)
@@ -239,16 +277,21 @@ def __init__(
 
         self.norm_out = nn.LayerNorm(dim)
 
+    @beartype
     def forward(
         self,
         x,
+        video_shape: Tuple[int, int, int, int] = None,
         attn_bias = None,
         context = None,
         self_attn_mask = None,
         cross_attn_context_mask = None
     ):
 
-        for self_attn, cross_attn, ff in self.layers:
+        for peg, self_attn, cross_attn, ff in self.layers:
+            if exists(peg):
+                x = peg(x, shape = video_shape) + x
+
             x = self_attn(x, attn_bias = attn_bias, mask = self_attn_mask) + x
 
             if exists(cross_attn) and exists(context):
diff --git a/phenaki_pytorch/cvivit.py b/phenaki_pytorch/cvivit.py
index 25f96b9..972b7e2 100644
--- a/phenaki_pytorch/cvivit.py
+++ b/phenaki_pytorch/cvivit.py
@@ -98,24 +98,6 @@ def grad_layer_wrt_loss(loss, layer):
         retain_graph = True
     )[0].detach()
 
-# PEG - position generating module
-
-class PEG(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dsconv = nn.Conv3d(dim, dim, 3, groups = dim)
-
-    def forward(self, x):
-        x = rearrange(x, 'b ... d -> b d ...')
-        res = x.clone() # residual
-
-        x = F.pad(x, (1, 1, 1, 1, 2, 0), value = 0.)
-        x = self.dsconv(x)
-
-        out = x + res
-        out = rearrange(out, 'b d ... -> b ... d')
-        return out
-
 # discriminator
 
 class DiscriminatorBlock(nn.Module):
@@ -245,7 +227,8 @@ def __init__(
 
         self.spatial_rel_pos_bias = ContinuousPositionBias(dim = dim, heads = heads)
 
-        assert (self.image_size[0] % patch_height) == 0 and (self.image_size[1] % patch_width) == 0
+        image_height, image_width = self.image_size
+        assert (image_height % patch_height) == 0 and (image_width % patch_width) == 0
 
         self.to_patch_emb_first_frame = nn.Sequential(
             Rearrange('b c 1 (h p1) (w p2) -> b 1 h w (c p1 p2)', p1 = patch_height, p2 = patch_width),
@@ -257,17 +240,13 @@ def __init__(
             nn.Linear(channels * patch_width * patch_height * temporal_patch_size, dim)
         )
 
-        self.enc_peg = PEG(dim = dim)
-
-        self.enc_spatial_transformer = Transformer(dim = dim, depth = spatial_depth, dim_head = dim_head, heads = heads)
-        self.enc_temporal_transformer = Transformer(dim = dim, depth = temporal_depth, dim_head = dim_head, heads = heads, causal = True)
+        self.enc_spatial_transformer = Transformer(dim = dim, depth = spatial_depth, dim_head = dim_head, heads = heads, peg = True, peg_causal = True)
+        self.enc_temporal_transformer = Transformer(dim = dim, depth = temporal_depth, dim_head = dim_head, heads = heads, causal = True, peg = True, peg_causal = True)
 
         self.vq = VectorQuantize(dim = dim, codebook_size = codebook_size, use_cosine_sim = True)
 
-        self.dec_peg = PEG(dim = dim)
-
-        self.dec_spatial_transformer = Transformer(dim = dim, depth = spatial_depth, dim_head = dim_head, heads = heads)
-        self.dec_temporal_transformer = Transformer(dim = dim, depth = temporal_depth, dim_head = dim_head, heads = heads, causal = True)
+        self.dec_spatial_transformer = Transformer(dim = dim, depth = spatial_depth, dim_head = dim_head, heads = heads, peg = True, peg_causal = True)
+        self.dec_temporal_transformer = Transformer(dim = dim, depth = temporal_depth, dim_head = dim_head, heads = heads, causal = True, peg = True, peg_causal = True)
 
         self.to_pixels_first_frame = nn.Sequential(
             nn.Linear(dim, channels * patch_width * patch_height),
@@ -394,13 +373,13 @@ def encode(
         b = tokens.shape[0]
         h, w = self.patch_height_width
 
-        tokens = self.enc_peg(tokens)
+        video_shape = tuple(tokens.shape[:-1])
 
         tokens = rearrange(tokens, 'b t h w d -> (b t) (h w) d')
 
         attn_bias = self.spatial_rel_pos_bias(h, w, device = tokens.device)
 
-        tokens = self.enc_spatial_transformer(tokens, attn_bias = attn_bias)
+        tokens = self.enc_spatial_transformer(tokens, attn_bias = attn_bias, video_shape = video_shape)
 
         tokens = rearrange(tokens, '(b t) (h w) d -> b t h w d', b = b, h = h , w = w)
 
@@ -408,7 +387,7 @@ def encode(
 
         tokens = rearrange(tokens, 'b t h w d -> (b h w) t d')
 
-        tokens = self.enc_temporal_transformer(tokens)
+        tokens = self.enc_temporal_transformer(tokens, video_shape = video_shape)
 
         tokens = rearrange(tokens, '(b h w) t d -> b t h w d', b = b, h = h, w = w)
 
@@ -424,13 +403,13 @@ def decode(
         if tokens.ndim == 3:
             tokens = rearrange(tokens, 'b (t h w) d -> b t h w d', h = h, w = w)
 
-        tokens = self.dec_peg(tokens)
+        video_shape = tuple(tokens.shape[:-1])
 
         # decode - temporal
 
         tokens = rearrange(tokens, 'b t h w d -> (b h w) t d')
 
-        tokens = self.dec_temporal_transformer(tokens)
+        tokens = self.dec_temporal_transformer(tokens, video_shape = video_shape)
 
         tokens = rearrange(tokens, '(b h w) t d -> b t h w d', b = b, h = h, w = w)
 
@@ -440,7 +419,7 @@ def decode(
 
         attn_bias = self.spatial_rel_pos_bias(h, w, device = tokens.device)
 
-        tokens = self.dec_spatial_transformer(tokens, attn_bias = attn_bias)
+        tokens = self.dec_spatial_transformer(tokens, attn_bias = attn_bias, video_shape = video_shape)
 
         tokens = rearrange(tokens, '(b t) (h w) d -> b t h w d', b = b, h = h , w = w)
 
diff --git a/phenaki_pytorch/phenaki_pytorch.py b/phenaki_pytorch/phenaki_pytorch.py
index 127d994..5dfdf1f 100644
--- a/phenaki_pytorch/phenaki_pytorch.py
+++ b/phenaki_pytorch/phenaki_pytorch.py
@@ -130,6 +130,7 @@ def __init__(
             has_cross_attn = not self.unconditional,
             dim_head = dim_head,
             heads = heads,
+            peg = True,
             **kwargs
         )
 
@@ -177,6 +178,8 @@ def forward(
             keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
             text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask
 
+        video_shape = (b, *video_patch_shape)
+
         x = self.token_emb(x)
 
         x = self.pos_emb(torch.arange(n, device = device)) + x
@@ -184,6 +187,7 @@ def forward(
 
         x = self.transformer(
             x,
+            video_shape = video_shape,
             attn_bias = rel_pos_bias,
             self_attn_mask = video_mask,
             cross_attn_context_mask = text_mask,
@@ -251,6 +255,7 @@ def __init__(
 
         self.transformer = Transformer(
             dim = dim,
+            peg = True,
             has_cross_attn = has_cross_attn,
             **kwargs
         )
@@ -282,6 +287,8 @@ def forward(
         context = None,
         **kwargs
     ):
+        video_shape = x.shape
+
         x = rearrange(x, 'b ... -> b (...)')
 
         b, n, device = *x.shape, x.device
@@ -297,6 +304,7 @@ def forward(
 
         x = self.transformer(
             x,
+            video_shape = video_shape,
             context = context,
             cross_attn_context_mask = text_mask,
             **kwargs
diff --git a/setup.py b/setup.py
index e510967..93b4db3 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'phenaki-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.55',
+  version = '0.0.56',
   license='MIT',
   description = 'Phenaki - Pytorch',
   author = 'Phil Wang',
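
For reference, below is a minimal standalone sketch of how the `PEG` (position generating) module introduced by this patch is applied inside a transformer block: flat tokens of shape `(batch, frames * height * width, dim)` are temporarily reshaped back to the video layout given by `video_shape`, run through a depthwise 3d convolution with causal frame padding, and added back as a residual (`x = peg(x, shape = video_shape) + x`). The code mirrors the module added in `phenaki_pytorch/attention.py`; the `exists` helper follows the repo's own utility, the `beartype` runtime type-check decorator is omitted here for brevity, and the tensor sizes in the usage lines are arbitrary example values.

```python
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

def exists(val):
    # same convention as the repo's helper
    return val is not None

class PEG(nn.Module):
    def __init__(self, dim, causal = False):
        super().__init__()
        self.causal = causal
        # depthwise 3d conv over (frames, height, width), one filter per channel
        self.dsconv = nn.Conv3d(dim, dim, 3, groups = dim)

    def forward(self, x, shape = None):
        # x may arrive flattened as (batch, frames * height * width, dim)
        needs_shape = x.ndim == 3
        assert not (needs_shape and not exists(shape))

        orig_shape = x.shape

        if needs_shape:
            x = x.reshape(*shape, -1)   # (batch, frames, height, width, dim)

        x = rearrange(x, 'b ... d -> b d ...')

        # causal: pad 2 frames on the left so each position only sees current and past frames
        frame_padding = (2, 0) if self.causal else (1, 1)

        x = F.pad(x, (1, 1, 1, 1, *frame_padding), value = 0.)
        x = self.dsconv(x)

        x = rearrange(x, 'b d ... -> b ... d')

        if needs_shape:
            x = rearrange(x, 'b ... d -> b (...) d')

        return x.reshape(orig_shape)

# usage, with arbitrary example sizes

peg = PEG(dim = 32, causal = True)

video_shape = (2, 4, 8, 8)                            # (batch, frames, height, width)
tokens = torch.randn(2, 4 * 8 * 8, 32)                # flattened video tokens

tokens = peg(tokens, shape = video_shape) + tokens    # residual, as in the patched Transformer
```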