Commit 666bbba

just make sure unconditional and conditional video sampling during training is working properly
lucidrains committed Dec 2, 2022
1 parent 31b94f9 commit 666bbba
Showing 4 changed files with 110 additions and 18 deletions.
README.md (6 additions, 1 deletion)

@@ -200,7 +200,12 @@ trainer = PhenakiTrainer(
     batch_size = 4,
     grad_accum_every = 4,
     train_on_images = False,  # if your mock dataset above returns (images, caption) pairs, set this to True
-    dataset = dataset         # pass in your dataset here
+    dataset = dataset,        # pass in your dataset here
+    sample_texts = [          # list of captions for sampling
+        'a whale breaching from afar',
+        'young girl blowing out candles on her birthday cake',
+        'fireworks with blue and green sparkles'
+    ]
 )
 
 trainer.train()
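With `sample_texts` supplied (and `train_on_images = False`), the trainer periodically saves text-conditioned samples as gifs named after their captions, as the trainer changes below show. A sketch of the resulting layout, assuming the default `./results` folder (the milestone index and exact filenames are illustrative):

results/
    videos.1/
        a_whale_breaching_from_afar.gif
        young_girl_blowing_out_candles_on_her_birthday_cake.gif
        fireworks_with_blue_and_green_sparkles.gif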
phenaki_pytorch/phenaki_pytorch.py (1 addition, 1 deletion)

@@ -573,7 +573,7 @@ def sample(

     # get video token ids
 
-    shape = (1, num_tokens)
+    shape = (batch_size, num_tokens)
 
     video_token_ids = torch.full(shape, self.mask_id, device = device)
     mask = torch.ones(shape, device = device, dtype = torch.bool)
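In isolation, the effect of this one-line fix: the initial token ids and mask are now allocated per batch element instead of for a single sample. A minimal sketch with hypothetical sizes (`mask_id = -1` stands in for `self.mask_id`):

import torch

batch_size, num_tokens, mask_id = 4, 1024, -1  # hypothetical values

# one fully-masked token sequence per requested sample (previously shape was (1, num_tokens))
shape = (batch_size, num_tokens)

video_token_ids = torch.full(shape, mask_id)
mask = torch.ones(shape, dtype = torch.bool)

assert video_token_ids.shape == (4, 1024)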
phenaki_pytorch/phenaki_trainer.py (102 additions, 15 deletions)

@@ -1,12 +1,13 @@
 import math
 import copy
 from pathlib import Path
-from random import random
+from random import random, choices
 from functools import partial
 from collections import namedtuple
 from multiprocessing import cpu_count
 
 from beartype import beartype
+from typing import Optional, List, Iterable
 
 import torch
 from torch import nn, einsum
@@ -68,6 +69,59 @@ def elements_to_device_if_tensor(arr, device):
         output.append(el)
     return output
 
+def split_iterable(it, split_size):
+    accum = []
+    for ind in range(math.ceil(len(it) / split_size)):
+        start_index = ind * split_size
+        accum.append(it[start_index: (start_index + split_size)])
+    return accum
+
+def split(t, split_size = None):
+    if not exists(split_size):
+        return t
+
+    if isinstance(t, torch.Tensor):
+        return t.split(split_size, dim = 0)
+
+    if isinstance(t, Iterable):
+        return split_iterable(t, split_size)
+
+    raise TypeError('t must be a Tensor or an Iterable to be split')
+
+def find_first(cond, arr):
+    for el in arr:
+        if cond(el):
+            return el
+    return None
+
+def split_args_and_kwargs(*args, batch_size = None, split_size = None, **kwargs):
+    all_args = (*args, *kwargs.values())
+    len_all_args = len(all_args)
+
+    if not exists(batch_size):
+        first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
+        assert exists(first_tensor)
+        batch_size = len(first_tensor)
+
+    split_size = default(split_size, batch_size)
+    num_chunks = math.ceil(batch_size / split_size)
+
+    dict_len = len(kwargs)
+    dict_keys = kwargs.keys()
+    split_kwargs_index = len_all_args - dict_len
+
+    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
+    chunk_sizes = tuple(map(len, split_all_args[0]))
+
+    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
+        chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
+        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
+        chunk_size_frac = chunk_size / batch_size
+        yield chunk_size_frac, (chunked_args, chunked_kwargs)
+
+def simple_slugify(text, max_length = 255):
+    return text.replace('-', '_').replace(',', '').replace(' ', '_').replace('|', '--').strip('-_')[:max_length]
 # trainer class
 
 @beartype
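For orientation, a small sketch of what the new helpers do, assuming they are imported from `phenaki_pytorch.phenaki_trainer` (all values hypothetical):

import torch
from phenaki_pytorch.phenaki_trainer import split, split_iterable, simple_slugify

t = torch.randn(10, 3)
print([chunk.shape[0] for chunk in split(t, split_size = 4)])  # [4, 4, 2], delegates to Tensor.split

print(split_iterable(list(range(10)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

print(simple_slugify('a whale breaching from afar'))  # a_whale_breaching_from_afar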
@@ -96,6 +150,7 @@ def __init__(
         fp16 = False,
         split_batches = True,
         convert_image_to = None,
+        sample_texts: Optional[List[str]] = None,
         dataset = None,
         dataset_fields = ('videos', 'texts', 'video_frame_masks')
     ):
@@ -105,6 +160,8 @@

         assert exists(cvivit)
 
+        assert maskgit.unconditional or exists(sample_texts), 'if maskgit is to be trained text-conditioned, a `sample_texts` List[str] must be given'
+
         self.accelerator = Accelerator(
             split_batches = split_batches,
             mixed_precision = 'fp16' if fp16 else 'no'
@@ -116,6 +173,9 @@

         assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
         self.num_samples = num_samples
+        self.sample_texts = sample_texts
+        self.unconditional = maskgit.unconditional
+
         self.save_and_sample_every = save_and_sample_every
 
         self.batch_size = batch_size
@@ -162,7 +222,7 @@ def __init__(
         self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
 
         self.results_folder = Path(results_folder)
-        assert self.results_folder.exists()
+        self.results_folder.mkdir(parents = True, exist_ok = True)
 
     def print(self, msg):
         self.accelerator.print(msg)
@@ -245,27 +305,54 @@ def train_step(self):
         self.model.eval()
         milestone = self.step // self.save_and_sample_every
 
-        if not self.train_on_images: # sample videos as gifs
-            with torch.no_grad():
-                groups = num_to_groups(self.num_samples, self.batch_size)
-                sampled_videos = [self.model.sample(num_frames = self.sample_num_frames, batch_size = b) for b in groups]
-                sampled_videos = torch.cat(sampled_videos, dim = 0)
-
-            for ind, video_tensor in enumerate(sampled_videos.unbind(dim = 0)):
-                video_tensor_to_gif(video_tensor, str(self.results_folder / f'{ind}.gif'))
-        else:
-            nrows = int(math.sqrt(self.num_samples))
-
-            with torch.no_grad():
-                groups = num_to_groups(self.num_samples, self.batch_size)
-                sampled_images = [self.model.sample_images(batch_size = b) for b in groups]
-                sampled_images = torch.cat(sampled_images, dim = 0)
-
-            sampled_images = sampled_images.detach().cpu().float().clamp(0., 1.)
+        # whether to pass in texts or not
+
+        sample_kwargs = dict()
+
+        if not self.unconditional:
+            texts = choices(self.sample_texts, k = self.num_samples)
+        else:
+            texts = (None,) * self.num_samples
+
+        sample_kwargs = {'texts': texts}
+
+        # method to call
+
+        method_name = 'sample_images' if self.train_on_images else 'sample'
+        sample_method = getattr(self.model, method_name)
+
+        # evaluate in groups, splitting the kwargs appropriately
+
+        with torch.no_grad():
+            groups = num_to_groups(self.num_samples, self.batch_size)
+            args_kwargs_iter = split_args_and_kwargs(batch_size = self.num_samples, split_size = self.batch_size, **sample_kwargs)
+
+            all_sampled = []
+            for group_batch_size, (_, (_, kwargs)) in zip(groups, args_kwargs_iter):
+                _kwargs = kwargs if not self.unconditional else dict()
+                sampled = sample_method(num_frames = self.sample_num_frames, batch_size = group_batch_size, **_kwargs)
+                all_sampled.append(sampled)
+
+        # save video and images differently
+
+        if not self.train_on_images:
+            sampled_videos = torch.cat(all_sampled, dim = 0)
+            milestone_folder = self.results_folder / f'videos.{milestone}'
+            milestone_folder.mkdir(parents = True, exist_ok = True)
+
+            for ind, (video_tensor, video_caption) in enumerate(zip(sampled_videos.unbind(dim = 0), texts)):
+                slugged_video_caption = simple_slugify(video_caption) if exists(video_caption) else str(ind)
+                video_tensor_to_gif(video_tensor, str(milestone_folder / f'{slugged_video_caption}.gif'))
+        else:
+            nrows = int(math.sqrt(self.num_samples))
+
+            # gather all grouped samples into one batch (`sampled_videos` is only defined in the video branch)
+            sampled_images = torch.cat(all_sampled, dim = 0).detach().cpu().float().clamp(0., 1.)
             grid = make_grid(sampled_images, nrow = nrows, normalize = True, value_range = (0, 1))
 
             save_image(grid, str(self.results_folder / f'{milestone}.png'))
 
+        # save checkpoints
+
         self.save(milestone)
 
         self.step += 1
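A quick sketch of the new sampling flow in `train_step`, assuming `split_args_and_kwargs` is imported from `phenaki_pytorch.phenaki_trainer` (captions and sizes hypothetical): captions are drawn with replacement via `random.choices`, then chunked so each sampling call receives only its share of texts.

from random import choices
from phenaki_pytorch.phenaki_trainer import split_args_and_kwargs

num_samples, batch_size = 9, 4
sample_texts = ['a whale breaching from afar', 'fireworks with blue and green sparkles']

texts = choices(sample_texts, k = num_samples)  # draw captions with replacement

# chunk the texts kwarg to match sampling groups of size 4, 4 and 1
for frac, (_, kwargs) in split_args_and_kwargs(batch_size = num_samples, split_size = batch_size, texts = texts):
    print(frac, len(kwargs['texts']))  # 4/9 4, 4/9 4, 1/9 1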
setup.py (1 addition, 1 deletion)

@@ -3,7 +3,7 @@
 setup(
     name = 'phenaki-pytorch',
     packages = find_packages(exclude=[]),
-    version = '0.0.47',
+    version = '0.0.49',
     license='MIT',
     description = 'Phenaki - Pytorch',
     author = 'Phil Wang',
