Skip to content

Commit

Permalink
address #21
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 28, 2022
1 parent cbc2b5b commit 25105be
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ from phenaki_pytorch import CViViT
cvivit = CViViT(
dim = 512,
codebook_size = 5000,
image_size = 256,
patch_size = 32,
temporal_patch_size = 2,
spatial_depth = 4,
Expand Down
6 changes: 4 additions & 2 deletions phenaki_pytorch/cvivit_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def __init__(
valid_frac = 0.05,
random_split_seed = 42,
ema_beta = 0.995,
ema_update_after_step = 500,
ema_update_every = 10,
ema_update_after_step = 0,
ema_update_every = 1,
apply_grad_penalty_every = 4,
accelerate_kwargs: dict = dict()
):
Expand All @@ -86,6 +86,7 @@ def __init__(
self.vae = vae

if self.is_main:
print('ema')
self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)

self.register_buffer('steps', torch.Tensor([0]))
Expand Down Expand Up @@ -250,6 +251,7 @@ def train_step(self):
# update exponential moving averaged generator

if self.is_main:
print('hmmmm')
self.ema_vae.update()

# sample results every so often
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'phenaki-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.38',
version = '0.0.39',
license='MIT',
description = 'Phenaki - Pytorch',
author = 'Phil Wang',
Expand All @@ -21,7 +21,7 @@
'accelerate',
'beartype',
'einops>=0.6',
'ema-pytorch',
'ema-pytorch>=0.1.0',
'opencv-python',
'pillow',
'numpy',
Expand Down

0 comments on commit 25105be

Please sign in to comment.