Skip to content

Commit

Permalink
address #42
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 29, 2024
1 parent 79d7e14 commit c771317
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ cvivit = CViViT(
)

maskgit = MaskGit(
num_tokens = 5000,
num_tokens = 65536,
max_seq_len = 1024,
dim = 512,
dim_context = 768,
Expand All @@ -173,7 +173,7 @@ maskgit = MaskGit(
# (1) define the critic

critic = TokenCritic(
num_tokens = 5000,
num_tokens = 65536,
max_seq_len = 1024,
dim = 512,
dim_context = 768,
Expand Down
23 changes: 18 additions & 5 deletions phenaki_pytorch/cvivit.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,18 +283,31 @@ def __init__(
nn.LayerNorm(dim)
)

transformer_kwargs = dict(
spatial_transformer_kwargs = dict(
dim = dim,
dim_head = dim_head,
heads = heads,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout,
causal = False,
peg = False,
)

# only temporal transformers have PEG and are causal

temporal_transformer_kwargs = dict(
dim = dim,
dim_head = dim_head,
heads = heads,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout,
causal = True,
peg = True,
peg_causal = True,
)

self.enc_spatial_transformer = Transformer(depth = spatial_depth, **transformer_kwargs)
self.enc_temporal_transformer = Transformer(depth = temporal_depth, **transformer_kwargs)
self.enc_spatial_transformer = Transformer(depth = spatial_depth, **spatial_transformer_kwargs)
self.enc_temporal_transformer = Transformer(depth = temporal_depth, **temporal_transformer_kwargs)

# offer look up free quantization
# https://arxiv.org/abs/2310.05737
Expand All @@ -306,8 +319,8 @@ def __init__(
else:
self.vq = VectorQuantize(dim = dim, codebook_size = codebook_size, use_cosine_sim = True)

self.dec_spatial_transformer = Transformer(depth = spatial_depth, **transformer_kwargs)
self.dec_temporal_transformer = Transformer(depth = temporal_depth, **transformer_kwargs)
self.dec_spatial_transformer = Transformer(depth = spatial_depth, **spatial_transformer_kwargs)
self.dec_temporal_transformer = Transformer(depth = temporal_depth, **temporal_transformer_kwargs)

self.to_pixels_first_frame = nn.Sequential(
nn.Linear(dim, channels * patch_width * patch_height),
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'phenaki-pytorch',
packages = find_packages(exclude=[]),
version = '0.4.2',
version = '0.5.0',
license='MIT',
description = 'Phenaki - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit c771317

Please sign in to comment.