From c7713179fdd3a353d464cedaf6206659cc10f3c7 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Mon, 29 Jul 2024 13:32:14 -0700 Subject: [PATCH] address https://github.com/lucidrains/phenaki-pytorch/issues/42 --- README.md | 4 ++-- phenaki_pytorch/cvivit.py | 23 ++++++++++++++++++----- setup.py | 2 +- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 03f1198..07d4471 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,7 @@ cvivit = CViViT( ) maskgit = MaskGit( - num_tokens = 5000, + num_tokens = 65536, max_seq_len = 1024, dim = 512, dim_context = 768, @@ -173,7 +173,7 @@ maskgit = MaskGit( # (1) define the critic critic = TokenCritic( - num_tokens = 5000, + num_tokens = 65536, max_seq_len = 1024, dim = 512, dim_context = 768, diff --git a/phenaki_pytorch/cvivit.py b/phenaki_pytorch/cvivit.py index 4560752..187d8a1 100644 --- a/phenaki_pytorch/cvivit.py +++ b/phenaki_pytorch/cvivit.py @@ -283,18 +283,31 @@ def __init__( nn.LayerNorm(dim) ) - transformer_kwargs = dict( + spatial_transformer_kwargs = dict( dim = dim, dim_head = dim_head, heads = heads, attn_dropout = attn_dropout, ff_dropout = ff_dropout, + causal = False, + peg = False, + ) + + # only temporal transformers have PEG and are causal + + temporal_transformer_kwargs = dict( + dim = dim, + dim_head = dim_head, + heads = heads, + attn_dropout = attn_dropout, + ff_dropout = ff_dropout, + causal = True, peg = True, peg_causal = True, ) - self.enc_spatial_transformer = Transformer(depth = spatial_depth, **transformer_kwargs) - self.enc_temporal_transformer = Transformer(depth = temporal_depth, **transformer_kwargs) + self.enc_spatial_transformer = Transformer(depth = spatial_depth, **spatial_transformer_kwargs) + self.enc_temporal_transformer = Transformer(depth = temporal_depth, **temporal_transformer_kwargs) # offer look up free quantization # https://arxiv.org/abs/2310.05737 @@ -306,8 +319,8 @@ def __init__( else: self.vq = VectorQuantize(dim = dim, codebook_size = codebook_size, use_cosine_sim = True) - self.dec_spatial_transformer = Transformer(depth = spatial_depth, **transformer_kwargs) - self.dec_temporal_transformer = Transformer(depth = temporal_depth, **transformer_kwargs) + self.dec_spatial_transformer = Transformer(depth = spatial_depth, **spatial_transformer_kwargs) + self.dec_temporal_transformer = Transformer(depth = temporal_depth, **temporal_transformer_kwargs) self.to_pixels_first_frame = nn.Sequential( nn.Linear(dim, channels * patch_width * patch_height), diff --git a/setup.py b/setup.py index 12c81f8..732c6b1 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'phenaki-pytorch', packages = find_packages(exclude=[]), - version = '0.4.2', + version = '0.5.0', license='MIT', description = 'Phenaki - Pytorch', author = 'Phil Wang',