Skip to content

Commit

Permalink
automatically infer what to pass to Phenaki from the Dataset, given t…
Browse files Browse the repository at this point in the history
…he type of the data element in the tuple, cached on first run
  • Loading branch information
lucidrains committed Dec 6, 2022
1 parent 8149858 commit adc2486
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 12 deletions.
79 changes: 68 additions & 11 deletions phenaki_pytorch/phenaki_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from multiprocessing import cpu_count

from beartype import beartype
from typing import Optional, List, Iterable
from beartype.door import is_bearable
from beartype.vale import Is
from typing import Optional, List, Iterable, Tuple
from typing_extensions import Annotated

import torch
from torch import nn, einsum
Expand All @@ -32,6 +35,28 @@

from phenaki_pytorch.data import ImageDataset, VideoDataset, video_tensor_to_gif, DataLoader

# constants

# maps each dataset field name to the runtime type used to recognise it
# insertion order is preserved, so entries are visited in the order written here

DATASET_FIELD_TYPE_CONFIG = {
    # float tensors with 4 or 5 dimensions
    'videos': Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.float and t.ndim in {4, 5}]
    ],

    # list of strings
    'texts': List[str],

    # long tensors of any shape
    'video_codebook_ids': Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.long]
    ],

    # boolean tensors of any shape
    'video_frame_mask': Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.bool]
    ],

    # float tensors with exactly 3 dimensions
    'text_embeds': Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.float and t.ndim == 3]
    ],
}

# helper functions

def exists(x):
Expand Down Expand Up @@ -122,6 +147,26 @@ def split_args_and_kwargs(*args, batch_size = None, split_size = None, **kwargs)
# single-pass character table for simple_slugify
# '-' and ' ' become '_', ',' is dropped, '|' becomes '--'

_SLUGIFY_TABLE = str.maketrans({
    '-': '_',
    ',': None,
    ' ': '_',
    '|': '--',
})

def simple_slugify(text, max_length = 255):
    """Turn free text into a simple slug.

    Hyphens and spaces become underscores, commas are removed and pipes become
    a double hyphen. Leading and trailing '-' / '_' are stripped, and the result
    is then cut to at most `max_length` characters.

    None of the substitutions produces a character that another substitution
    consumes, so one `str.translate` pass gives the same result as applying
    the replacements one after another.
    """
    slug = text.translate(_SLUGIFY_TABLE).strip('-_')
    return slug[:max_length]

def has_duplicates(tup):
    """Return True if any element of `tup` occurs more than once, else False.

    Elements must be hashable. Any iterable is accepted and it is consumed at
    most once; iteration stops at the first repeated element.
    """
    seen = set()

    for el in tup:
        if el in seen:
            return True
        seen.add(el)

    return False

def determine_types(data, config):
    """Infer a field name for every element of `data`.

    Each element is tested against the types in `config` (a mapping of field
    name to type hint) using `is_bearable`, in the mapping's iteration order.
    The name of the first type that accepts the element is used.

    Returns a tuple of field names, one per element of `data`, in the same order.

    Raises TypeError if an element is not accepted by any type in `config`.
    """
    output = []

    for index, el in enumerate(data):
        for name, data_type in config.items():
            if is_bearable(el, data_type):
                output.append(name)
                break
        else:
            # describe only the offending element - formatting the whole of `data`
            # would dump every tensor in the batch and hide which element failed
            traits = ', '.join(
                f'{attr} = {getattr(el, attr)}'
                for attr in ('dtype', 'ndim')
                if hasattr(el, attr)
            )

            described = type(el).__name__ + (f' with {traits}' if traits else '')
            raise TypeError(f'unable to determine type of element at index {index} ({described})')

    return tuple(output)

# trainer class

@beartype
Expand Down Expand Up @@ -153,14 +198,16 @@ def __init__(
sample_texts_file_path = None, # path to a text file with video captions, delimited by newline
sample_texts: Optional[List[str]] = None,
dataset: Optional[Dataset] = None,
dataset_fields = ('videos', 'texts', 'video_frame_masks')
dataset_fields: Optional[Tuple[str, ...]] = None
):
super().__init__()
maskgit = phenaki.maskgit
cvivit = phenaki.cvivit

assert exists(cvivit), 'cvivit must be present on phenaki'

# define accelerator

self.accelerator = Accelerator(
split_batches = split_batches,
mixed_precision = 'fp16' if fp16 else 'no'
Expand All @@ -173,6 +220,18 @@ def __init__(
assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
self.unconditional = maskgit.unconditional

# training related variables

self.batch_size = batch_size
self.grad_accum_every = grad_accum_every

self.max_grad_norm = max_grad_norm

self.train_num_steps = train_num_steps
self.image_size = phenaki.cvivit.image_size

# sampling related variables

self.num_samples = num_samples

self.sample_texts = None
Expand All @@ -190,14 +249,6 @@ def __init__(

self.save_and_sample_every = save_and_sample_every

self.batch_size = batch_size
self.grad_accum_every = grad_accum_every

self.max_grad_norm = max_grad_norm

self.train_num_steps = train_num_steps
self.image_size = phenaki.cvivit.image_size

# dataset and dataloader

dataset_klass = ImageDataset if train_on_images else VideoDataset
Expand Down Expand Up @@ -236,6 +287,12 @@ def __init__(
self.results_folder = Path(results_folder)
self.results_folder.mkdir(parents = True, exist_ok = True)

def data_tuple_to_kwargs(self, data):
    """Pair each element of `data` with its field name and return them as a dict.

    If `self.dataset_fields` is not set (as judged by `exists`), the field names
    are inferred from `data` with `determine_types` and `DATASET_FIELD_TYPE_CONFIG`,
    then stored on the instance so later calls reuse them.

    Raises ValueError if inference gives the same field name to more than one element.
    """
    if not exists(self.dataset_fields):
        inferred_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)

        # a repeated name would make the dict below silently keep only the last
        # element for that name, so refuse it before anything is cached
        if has_duplicates(inferred_fields):
            raise ValueError(f'dataset fields must not have duplicate field names, but inferred {inferred_fields}')

        self.dataset_fields = inferred_fields

    return dict(zip(self.dataset_fields, data))

def print(self, msg):
    """Print `msg` by delegating to `self.accelerator.print`."""
    self.accelerator.print(msg)

Expand Down Expand Up @@ -292,7 +349,7 @@ def train_step(self):
for _ in range(self.grad_accum_every):
data = next(self.dl)
data = elements_to_device_if_tensor(data, device)
data_kwargs = dict(zip(self.dataset_fields, data))
data_kwargs = self.data_tuple_to_kwargs(data)

assert not (self.train_on_images and data_kwargs['videos'].ndim != 4), 'you have it set to train on images, but the dataset is not returning tensors of 4 dimensions (batch, channels, height, width)'

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'phenaki-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.51',
version = '0.0.52',
license='MIT',
description = 'Phenaki - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit adc2486

Please sign in to comment.