fix all issues with cvivit training, making sure it can be trained on either images or video, update readme
lucidrains committed Nov 30, 2022
1 parent 134b0a3 commit e949742
Showing 4 changed files with 31 additions and 18 deletions.
16 changes: 12 additions & 4 deletions README.md
@@ -18,7 +18,7 @@ C-ViViT

```python
import torch
from phenaki_pytorch import CViViT
from phenaki_pytorch import CViViT, CViViTTrainer

cvivit = CViViT(
dim = 512,
@@ -32,10 +32,18 @@ cvivit = CViViT(
heads = 8
).cuda()

video = torch.randn(1, 3, 17, 256, 256).cuda() # (batch, channels, frames + 1 leading frame, image height, image width)
trainer = CViViTTrainer(
cvivit,
folder = '/home/phil/dl/phenaki-pytorch',
batch_size = 4,
grad_accum_every = 4,
train_on_images = False, # you can train on images first, before fine tuning on video, for sample efficiency
    use_ema = False,          # recommended to be turned on (keeps exponential moving averaged cvivit) unless you don't have enough resources
num_train_steps = 10000
)

trainer.train() # reconstructions and checkpoints will be saved periodically to ./results

loss = cvivit(video)
loss.backward()
```
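
Since the commit message highlights training on either images or video, here is a minimal hedged sketch of the two-stage recipe the new `train_on_images` flag enables, reusing the `cvivit` defined above. The folder paths and step counts below are placeholders, not values from the repo.

```python
# hypothetical two-stage recipe: pre-train the cvivit on still images,
# then fine tune it on full video clips with the same trainer class

image_trainer = CViViTTrainer(
    cvivit,
    folder = '/path/to/images',      # placeholder folder of training images
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = True,          # ImageDataset is used under the hood
    num_train_steps = 10000
)

image_trainer.train()

video_trainer = CViViTTrainer(
    cvivit,
    folder = '/path/to/videos',      # placeholder folder of gif / mp4 clips
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = False,         # switches to VideoDataset, 17 frames per clip by default
    num_train_steps = 10000
)

video_trainer.train()
```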

Phenaki
9 changes: 7 additions & 2 deletions phenaki_pytorch/cvivit_trainer.py
@@ -62,6 +62,7 @@ def __init__(
batch_size,
folder,
train_on_images = False,
num_frames = 17,
lr = 3e-4,
grad_accum_every = 1,
wd = 0.,
@@ -115,7 +116,7 @@ def __init__(
if train_on_images:
self.ds = ImageDataset(folder, image_size)
else:
self.ds = VideoDataset(folder, image_size)
self.ds = VideoDataset(folder, image_size, num_frames = num_frames)

# split for validation

@@ -265,7 +266,11 @@ def train_step(self):
for model, filename in vaes_to_evaluate:
model.eval()

imgs = next(self.valid_dl_iter)
valid_data = next(self.valid_dl_iter)

# for now, only save reconstructed images
imgs = valid_data[:, :, 0] if valid_data.ndim == 5 else valid_data

imgs = imgs.to(device)

recons = model(imgs, return_recons_only = True)
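For clarity on the validation change above: the dataloader now yields either image batches (4-D) or video batches (5-D), and only the first frame of a video batch is kept so reconstructions can still be saved as images. A small standalone sketch of that selection logic, using dummy tensors:

```python
import torch

# dummy batches standing in for what the validation dataloader yields
video_batch = torch.randn(4, 3, 17, 256, 256)  # (batch, channels, frames, height, width)
image_batch = torch.randn(4, 3, 256, 256)      # (batch, channels, height, width)

def first_frame_or_images(valid_data):
    # mirror of the logic in train_step: take frame 0 of a video batch,
    # pass an image batch through unchanged
    return valid_data[:, :, 0] if valid_data.ndim == 5 else valid_data

print(first_frame_or_images(video_batch).shape)  # torch.Size([4, 3, 256, 256])
print(first_frame_or_images(image_batch).shape)  # torch.Size([4, 3, 256, 256])
```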
22 changes: 11 additions & 11 deletions phenaki_pytorch/data.py
@@ -21,6 +21,9 @@ def exists(val):
def identity(t, *args, **kwargs):
return t

def pair(val):
return val if isinstance(val, tuple) else (val, val)

def cast_num_frames(t, *, frames):
f = t.shape[1]

@@ -133,12 +136,12 @@ def video_to_tensor(
while check:
check, frame = video.read()

if exists(crop_size):
frame = crop_center(frame, crop_size, crop_size)

if not check:
continue

if exists(crop_size):
frame = crop_center(frame, *pair(crop_size))

frames.append(rearrange(frame, '... -> 1 ...'))

frames = np.array(np.concatenate(frames[:-1], axis = 0)) # convert list of frames to numpy array
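
The reordering in this hunk matters because OpenCV's `VideoCapture.read()` returns `(False, None)` once the stream is exhausted, so the center crop must only run after the `check` guard. A hedged standalone sketch of that loop (the filename and crop size are illustrative; `crop_center` and `pair` are the helpers from `phenaki_pytorch/data.py` shown in this diff):

```python
import cv2
import numpy as np
from einops import rearrange
from phenaki_pytorch.data import crop_center, pair  # both defined in data.py per this diff

video = cv2.VideoCapture('sample.mp4')  # hypothetical local clip

frames = []
check = True
while check:
    check, frame = video.read()

    # guard first: read() yields (False, None) at the end of the stream,
    # so a None frame must never reach the crop below
    if not check:
        continue

    frame = crop_center(frame, *pair(256))           # illustrative 256 x 256 center crop
    frames.append(rearrange(frame, '... -> 1 ...'))

video.release()
frames = np.concatenate(frames, axis = 0)            # (num_frames, height, width, channels)
```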
@@ -180,13 +183,10 @@ def crop_center(
cropx, # Length of the final image in the x direction.
cropy # Length of the final image in the y direction.
) -> torch.Tensor:
try:
y, x, c = img.shape
startx = x // 2 - cropx // 2
starty = y // 2 - cropy // 2
return img[starty:(starty + cropy), startx:(startx + cropx), :]
except:
pass
y, x, c = img.shape
startx = x // 2 - cropx // 2
starty = y // 2 - cropy // 2
return img[starty:(starty + cropy), startx:(startx + cropx), :]

# video dataset

@@ -196,7 +196,7 @@ def __init__(
folder,
image_size,
channels = 3,
num_frames = 16,
num_frames = 17,
horizontal_flip = False,
force_num_frames = True,
exts = ['gif', 'mp4']
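For reference, the new `pair` helper lets `crop_size` be either a single int or an `(x, y)` tuple before it is splatted into `crop_center`; a short sketch of how the two compose, using a dummy numpy frame:

```python
import numpy as np
from phenaki_pytorch.data import crop_center, pair   # defined in data.py per this diff

frame = np.zeros((480, 640, 3), dtype = np.uint8)     # dummy (height, width, channels) frame

# a scalar crop size is duplicated by pair; a tuple passes through as (cropx, cropy)
print(crop_center(frame, *pair(256)).shape)           # (256, 256, 3)
print(crop_center(frame, *pair((320, 240))).shape)    # (240, 320, 3) - cropx is width, cropy is height
```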
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'phenaki-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.41',
version = '0.0.42',
license='MIT',
description = 'Phenaki - Pytorch',
author = 'Phil Wang',
