Skip to content

Commit

Permalink
allow to sample on more than one text, also fix phenaki trainer to sa…
Browse files Browse the repository at this point in the history
…mple videos and images depending on which is being trained on (will need to train phenaki on images with or without text before upgrading to video)
  • Loading branch information
lucidrains committed Dec 1, 2022
1 parent 8fcc2d3 commit b548fa0
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 18 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,15 @@ loss.backward()

# do the above for many steps, then ...

video = phenaki.sample(text = 'a squirrel examines an acorn', num_frames = 17, cond_scale = 5.) # (1, 3, 17, 256, 128)
video = phenaki.sample(texts = 'a squirrel examines an acorn', num_frames = 17, cond_scale = 5.) # (1, 3, 17, 256, 128)

# so in the paper, they do not really achieve 2 minutes of coherent video
# at each new scene with new text conditioning, they condition on the previous K frames
# you can easily achieve this with this framework as so

video_prime = video[:, :, -3:] # (1, 3, 3, 256, 128) # say K = 3

video_next = phenaki.sample(text = 'a cat watches the squirrel from afar', prime_frames = video_prime, num_frames = 14) # (1, 3, 14, 256, 128)
video_next = phenaki.sample(texts = 'a cat watches the squirrel from afar', prime_frames = video_prime, num_frames = 14) # (1, 3, 14, 256, 128)

# the total video

Expand Down
1 change: 1 addition & 0 deletions phenaki_pytorch/cvivit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
import copy
import math
from functools import wraps
Expand Down
2 changes: 0 additions & 2 deletions phenaki_pytorch/cvivit_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,6 @@ def train_step(self):
for tensor in recons.unbind(dim = 0):
video_tensor_to_gif(tensor, str(sampled_videos_path / f'{filename}.gif'))
else:
nrows = int(sqrt(self.batch_size))

imgs_and_recons = torch.stack((valid_data, recons), dim = 0)
imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')

Expand Down
23 changes: 15 additions & 8 deletions phenaki_pytorch/phenaki_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,32 +501,34 @@ def __init__(
assert cond_drop_prob > 0.
self.cond_drop_prob = cond_drop_prob # classifier free guidance for transformers - @crowsonkb

def sample_image(
def sample_images(
self,
*,
text,
texts: Union[List[str], str] = None,
batch_size = 1,
cond_scale = 3.,
starting_temperature = 0.9,
noise_K = 1.
):
single_framed_video = self.sample(
text = text,
texts = texts,
num_frames = 1,
cond_scale = cond_scale,
starting_temperature = starting_temperature,
noise_K = noise_K
)

return rearrange(single_framed_video, 'b c 1 h w -> b c h w')
return rearrange(single_framed_video, '... c 1 h w -> ... c h w')

@eval_decorator
@torch.no_grad()
def sample(
self,
*,
num_frames,
text = None,
texts: Union[List[str], str] = None,
prime_frames = None,
batch_size = 1,
cond_scale = 3.,
starting_temperature = 0.9,
noise_K = 1. # hyperparameter for noising of critic score in section 3.2 of token-critic paper, need to find correct value
Expand Down Expand Up @@ -555,11 +557,16 @@ def sample(

text_embeds = text_mask = None

if exists(text):
if exists(texts):
if isinstance(texts, str):
texts = [texts]

with torch.no_grad():
text_embeds = self.encode_texts([text], output_device = device)
text_embeds = self.encode_texts(texts, output_device = device)
text_mask = torch.any(text_embeds != 0, dim = -1)

batch_size = len(texts)

# derive video patch shape

patch_shape = self.cvivit.get_video_patch_shape(num_frames + prime_num_frames, include_first_frame = True)
Expand Down Expand Up @@ -736,7 +743,7 @@ def make_video(
scenes = []

for text, scene_num_frames, next_scene_prime_length in zip(texts, num_frames, prime_lengths):
video = phenaki.sample(text = text, prime_frames = video_prime, num_frames = scene_num_frames)
video = phenaki.sample(texts = text, prime_frames = video_prime, num_frames = scene_num_frames)
scenes.append(video)

video_prime = video[:, :, -next_scene_prime_length:]
Expand Down
29 changes: 24 additions & 5 deletions phenaki_pytorch/phenaki_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from torch.utils.data import Dataset, DataLoader

from torch.optim import Adam
from torchvision import transforms as T, utils

from torchvision import transforms as T
from torchvision.utils import make_grid, save_image

from einops import rearrange, reduce
from einops.layers.torch import Rearrange
Expand Down Expand Up @@ -120,6 +122,7 @@ def __init__(
dataset_klass = ImageDataset if train_on_images else VideoDataset

self.sample_num_frames = default(sample_num_frames, num_frames)
self.train_on_images = train_on_images

if train_on_images:
self.ds = ImageDataset(folder, self.image_size)
Expand Down Expand Up @@ -223,11 +226,27 @@ def train_step(self):
self.model.eval()
milestone = self.step // self.save_and_sample_every

with torch.no_grad():
batches = num_to_groups(self.num_samples, self.batch_size)
sampled_video = self.model.sample(num_frames = self.sample_num_frames)
if not self.train_on_images: # sample videos as gifs
with torch.no_grad():
groups = num_to_groups(self.num_samples, self.batch_size)
sampled_videos = [self.model.sample(num_frames = self.sample_num_frames, batch_size = b) for b in groups]
sampled_videos = torch.cat(sampled_videos, dim = 0)

for ind, video_tensor in enumerate(sampled_videos.unbind(dim = 0)):
video_tensor_to_gif(video_tensor, str(self.results_folder / f'{ind}.gif'))
else:
nrows = int(math.sqrt(self.num_samples))

with torch.no_grad():
groups = num_to_groups(self.num_samples, self.batch_size)
sampled_images = [self.model.sample_images(batch_size = b) for b in groups]
sampled_images = torch.cat(sampled_images, dim = 0)

sampled_images = sampled_images.detach().cpu().float().clamp(0., 1.)
grid = make_grid(sampled_images, nrow = nrows, normalize = True, value_range = (0, 1))

save_image(grid, str(self.results_folder / f'{milestone}.png'))

video_tensor_to_gif(sampled_video[0], str(self.results_folder / f'{milestone}.gif'))
self.save(milestone)

self.step += 1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'phenaki-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.44',
version = '0.0.45',
license='MIT',
description = 'Phenaki - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit b548fa0

Please sign in to comment.