add position generation at each layer of the transformer for cvivit, maskgit, and critic
lucidrains committed Dec 7, 2022
1 parent f25ae7f commit a81a47c
Showing 5 changed files with 66 additions and 37 deletions.
README.md (3 changes: 1 addition & 2 deletions)
@@ -441,9 +441,9 @@ trainer.train()
 - [x] add depthwise-convs to cvivit for position generating
 - [x] some basic video manipulation code, allow for sampled tensor to be saved as gif
 - [x] basic critic training code
+- [x] add position generating dsconv to maskgit too

 - [ ] get some basic critic sampling code, show comparison of with and without critic
-- [ ] add position generating dsconv to maskgit too
 - [ ] add all top of the line research for stabilizing transformers training
 - [ ] bring in concatenative token shift (temporal dimension)
 - [ ] add a DDPM upsampler, either port from imagen-pytorch or just rewrite a simple version here
@@ -452,7 +452,6 @@ trainer.train()
 - [ ] test maskgit + critic alone on oxford flowers dataset
 - [ ] support rectangular sized videos
 - [ ] add flash attention as an option for all transformers and cite @tridao
-- [ ] abstract out text conditioning module into own package, and take care of audiolm-pytorch at the same time

 ## Citations
phenaki_pytorch/attention.py (45 changes: 44 additions & 1 deletion)
@@ -3,6 +3,9 @@
 import torch.nn.functional as F
 from torch import nn, einsum

+from beartype import beartype
+from typing import Tuple
+
 from einops import rearrange, repeat

 # helpers
@@ -32,6 +35,38 @@ def FeedForward(dim, mult = 4):
         nn.Linear(inner_dim, dim, bias = False)
     )

+# PEG - position generating module
+
+class PEG(nn.Module):
+    def __init__(self, dim, causal = False):
+        super().__init__()
+        self.causal = causal
+        self.dsconv = nn.Conv3d(dim, dim, 3, groups = dim)
+
+    @beartype
+    def forward(self, x, shape: Tuple[int, int, int, int] = None):
+        needs_shape = x.ndim == 3
+        assert not (needs_shape and not exists(shape))
+
+        orig_shape = x.shape
+
+        if needs_shape:
+            x = x.reshape(*shape, -1)
+
+        x = rearrange(x, 'b ... d -> b d ...')
+
+        frame_padding = (2, 0) if self.causal else (1, 1)
+
+        x = F.pad(x, (1, 1, 1, 1, *frame_padding), value = 0.)
+        x = self.dsconv(x)
+
+        x = rearrange(x, 'b d ... -> b ... d')
+
+        if needs_shape:
+            x = rearrange(x, 'b ... d -> b (...) d')
+
+        return x.reshape(orig_shape)
+
 # attention

 class Attention(nn.Module):
@@ -224,6 +259,8 @@ def __init__(
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
+        peg = False,
+        peg_causal = False,
         attn_num_null_kv = 2,
         has_cross_attn = False
     ):
@@ -232,23 +269,29 @@

         for _ in range(depth):
             self.layers.append(nn.ModuleList([
+                PEG(dim = dim, causal = peg_causal) if peg else None,
                 Attention(dim = dim, dim_head = dim_head, heads = heads, causal = causal),
                 Attention(dim = dim, dim_head = dim_head, dim_context = dim_context, heads = heads, causal = False, num_null_kv = attn_num_null_kv) if has_cross_attn else None,
                 FeedForward(dim = dim, mult = ff_mult)
             ]))

         self.norm_out = nn.LayerNorm(dim)

+    @beartype
     def forward(
         self,
         x,
+        video_shape: Tuple[int, int, int, int] = None,
         attn_bias = None,
         context = None,
         self_attn_mask = None,
         cross_attn_context_mask = None
     ):

-        for self_attn, cross_attn, ff in self.layers:
+        for peg, self_attn, cross_attn, ff in self.layers:
+            if exists(peg):
+                x = peg(x, shape = video_shape) + x
+
             x = self_attn(x, attn_bias = attn_bias, mask = self_attn_mask) + x

             if exists(cross_attn) and exists(context):
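For reference, a minimal sketch (not part of the commit; tensor sizes are illustrative) of how the new PEG layer behaves: it accepts tokens either unflattened as (batch, frames, height, width, dim), or flattened as (batch, n, dim) together with an explicit shape, and the caller adds the residual, as in the Transformer loop above.

```python
import torch
from einops import rearrange

from phenaki_pytorch.attention import PEG  # module added in this commit

peg = PEG(dim = 32, causal = True)  # causal: conv padding only covers current + past frames

video_tokens = torch.randn(2, 4, 8, 8, 32)  # (batch, frames, height, width, dim)

# unflattened tokens pass straight through the depthwise conv
assert peg(video_tokens).shape == video_tokens.shape

# flattened tokens need the video shape so PEG can restore the 3d layout
flat_tokens = rearrange(video_tokens, 'b t h w d -> b (t h w) d')
out = peg(flat_tokens, shape = (2, 4, 8, 8))
assert out.shape == flat_tokens.shape

# the residual is added by the caller, as in Transformer.forward above
flat_tokens = out + flat_tokens
```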
phenaki_pytorch/cvivit.py (45 changes: 12 additions & 33 deletions)
@@ -98,24 +98,6 @@ def grad_layer_wrt_loss(loss, layer):
         retain_graph = True
     )[0].detach()

-# PEG - position generating module
-
-class PEG(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dsconv = nn.Conv3d(dim, dim, 3, groups = dim)
-
-    def forward(self, x):
-        x = rearrange(x, 'b ... d -> b d ...')
-        res = x.clone() # residual
-
-        x = F.pad(x, (1, 1, 1, 1, 2, 0), value = 0.)
-        x = self.dsconv(x)
-
-        out = x + res
-        out = rearrange(out, 'b d ... -> b ... d')
-        return out
-
 # discriminator

 class DiscriminatorBlock(nn.Module):
@@ -245,7 +227,8 @@ def __init__(

         self.spatial_rel_pos_bias = ContinuousPositionBias(dim = dim, heads = heads)

-        assert (self.image_size[0] % patch_height) == 0 and (self.image_size[1] % patch_width) == 0
+        image_height, image_width = self.image_size
+        assert (image_height % patch_height) == 0 and (image_width % patch_width) == 0

         self.to_patch_emb_first_frame = nn.Sequential(
             Rearrange('b c 1 (h p1) (w p2) -> b 1 h w (c p1 p2)', p1 = patch_height, p2 = patch_width),
@@ -257,17 +240,13 @@
             nn.Linear(channels * patch_width * patch_height * temporal_patch_size, dim)
         )

-        self.enc_peg = PEG(dim = dim)
-
-        self.enc_spatial_transformer = Transformer(dim = dim, depth = spatial_depth, dim_head = dim_head, heads = heads)
-        self.enc_temporal_transformer = Transformer(dim = dim, depth = temporal_depth, dim_head = dim_head, heads = heads, causal = True)
+        self.enc_spatial_transformer = Transformer(dim = dim, depth = spatial_depth, dim_head = dim_head, heads = heads, peg = True, peg_causal = True)
+        self.enc_temporal_transformer = Transformer(dim = dim, depth = temporal_depth, dim_head = dim_head, heads = heads, causal = True, peg = True, peg_causal = True)

         self.vq = VectorQuantize(dim = dim, codebook_size = codebook_size, use_cosine_sim = True)

-        self.dec_peg = PEG(dim = dim)
-
-        self.dec_spatial_transformer = Transformer(dim = dim, depth = spatial_depth, dim_head = dim_head, heads = heads)
-        self.dec_temporal_transformer = Transformer(dim = dim, depth = temporal_depth, dim_head = dim_head, heads = heads, causal = True)
+        self.dec_spatial_transformer = Transformer(dim = dim, depth = spatial_depth, dim_head = dim_head, heads = heads, peg = True, peg_causal = True)
+        self.dec_temporal_transformer = Transformer(dim = dim, depth = temporal_depth, dim_head = dim_head, heads = heads, causal = True, peg = True, peg_causal = True)

         self.to_pixels_first_frame = nn.Sequential(
             nn.Linear(dim, channels * patch_width * patch_height),
@@ -394,21 +373,21 @@ def encode(
         b = tokens.shape[0]
         h, w = self.patch_height_width

-        tokens = self.enc_peg(tokens)
+        video_shape = tuple(tokens.shape[:-1])

         tokens = rearrange(tokens, 'b t h w d -> (b t) (h w) d')

         attn_bias = self.spatial_rel_pos_bias(h, w, device = tokens.device)

-        tokens = self.enc_spatial_transformer(tokens, attn_bias = attn_bias)
+        tokens = self.enc_spatial_transformer(tokens, attn_bias = attn_bias, video_shape = video_shape)

         tokens = rearrange(tokens, '(b t) (h w) d -> b t h w d', b = b, h = h , w = w)

         # encode - temporal

         tokens = rearrange(tokens, 'b t h w d -> (b h w) t d')

-        tokens = self.enc_temporal_transformer(tokens)
+        tokens = self.enc_temporal_transformer(tokens, video_shape = video_shape)

         tokens = rearrange(tokens, '(b h w) t d -> b t h w d', b = b, h = h, w = w)

@@ -424,13 +403,13 @@ def decode(
         if tokens.ndim == 3:
             tokens = rearrange(tokens, 'b (t h w) d -> b t h w d', h = h, w = w)

-        tokens = self.dec_peg(tokens)
+        video_shape = tuple(tokens.shape[:-1])

         # decode - temporal

         tokens = rearrange(tokens, 'b t h w d -> (b h w) t d')

-        tokens = self.dec_temporal_transformer(tokens)
+        tokens = self.dec_temporal_transformer(tokens, video_shape = video_shape)

         tokens = rearrange(tokens, '(b h w) t d -> b t h w d', b = b, h = h, w = w)

@@ -440,7 +419,7 @@

         attn_bias = self.spatial_rel_pos_bias(h, w, device = tokens.device)

-        tokens = self.dec_spatial_transformer(tokens, attn_bias = attn_bias)
+        tokens = self.dec_spatial_transformer(tokens, attn_bias = attn_bias, video_shape = video_shape)

         tokens = rearrange(tokens, '(b t) (h w) d -> b t h w d', b = b, h = h , w = w)

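For reference, a short sketch (not from the commit; shapes are illustrative) of the encoder's new shape bookkeeping: video_shape is captured while the tokens are still (b, t, h, w, d), and after flattening for spatial attention each PEG layer uses it to restore the video layout around its depthwise conv.

```python
import torch
from einops import rearrange

from phenaki_pytorch.attention import PEG  # added in this commit

b, t, h, w, d = 2, 4, 8, 8, 32
tokens = torch.randn(b, t, h, w, d)

video_shape = tuple(tokens.shape[:-1])  # (b, t, h, w), captured before flattening

# spatial layout used by the encoder: batch/time merged, height/width merged
spatial_tokens = rearrange(tokens, 'b t h w d -> (b t) (h w) d')

# inside each transformer layer, PEG unflattens back to (b, t, h, w, d),
# applies the padded depthwise conv, then reflattens to the input layout
peg = PEG(dim = d, causal = True)
out = peg(spatial_tokens, shape = video_shape)
assert out.shape == spatial_tokens.shape
```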
phenaki_pytorch/phenaki_pytorch.py (8 changes: 8 additions & 0 deletions)
@@ -130,6 +130,7 @@ def __init__(
             has_cross_attn = not self.unconditional,
             dim_head = dim_head,
             heads = heads,
+            peg = True,
             **kwargs
         )

@@ -177,13 +178,16 @@ def forward(
         keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
         text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

+        video_shape = (b, *video_patch_shape)
+
         x = self.token_emb(x)
         x = self.pos_emb(torch.arange(n, device = device)) + x

         x = x * self.gradient_shrink_alpha + x.detach() * (1 - self.gradient_shrink_alpha)

         x = self.transformer(
             x,
+            video_shape = video_shape,
             attn_bias = rel_pos_bias,
             self_attn_mask = video_mask,
             cross_attn_context_mask = text_mask,
@@ -251,6 +255,7 @@

         self.transformer = Transformer(
             dim = dim,
+            peg = True,
             has_cross_attn = has_cross_attn,
             **kwargs
         )
Expand Down Expand Up @@ -282,6 +287,8 @@ def forward(
context = None,
**kwargs
):
video_shape = x.shape

x = rearrange(x, 'b ... -> b (...)')
b, n, device = *x.shape, x.device

@@ -297,6 +304,7 @@

         x = self.transformer(
             x,
+            video_shape = video_shape,
             context = context,
             cross_attn_context_mask = text_mask,
             **kwargs
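And a hedged sketch of the maskgit/critic side (the constructor arguments here are assumptions read off the diff, not the repo's documented API): the token grid's (batch, frames, height, width) shape is saved before flattening, then threaded through the transformer so every PEG layer can recover the video layout.

```python
import torch

from phenaki_pytorch.attention import Transformer  # updated in this commit

# a stack with a PEG at each layer, as maskgit / critic now construct
transformer = Transformer(dim = 64, depth = 2, peg = True, peg_causal = False)

token_ids = torch.randint(0, 8192, (2, 4, 8, 8))  # (batch, frames, height, width)
video_shape = tuple(token_ids.shape)              # saved before flattening, as in Critic.forward

x = torch.randn(2, 4 * 8 * 8, 64)                 # stand-in for the flattened token embeddings
out = transformer(x, video_shape = video_shape)
assert out.shape == x.shape
```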
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
     name = 'phenaki-pytorch',
     packages = find_packages(exclude=[]),
-    version = '0.0.55',
+    version = '0.0.56',
     license='MIT',
     description = 'Phenaki - Pytorch',
     author = 'Phil Wang',
