just duplicate phenaki trainer for token critic training code for now
lucidrains committed Dec 7, 2022
1 parent d60a47a commit f25ae7f
Showing 5 changed files with 387 additions and 68 deletions.
200 changes: 142 additions & 58 deletions README.md
@@ -8,6 +8,16 @@ Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord

<a href="https://www.youtube.com/watch?v=RYLomvaPWa4">AI Coffeebreak explanation</a>

## Appreciation

- <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work on cutting edge artificial intelligence research

- <a href="https://huggingface.co/">🤗 Huggingface</a> for their amazing transformers and accelerate library

- <a href="https://github.com/gmegh">Guillem</a> for his ongoing contributions

- You? If you are a great machine learning engineer and / or researcher, feel free to contribute to the frontier of open source generative AI

## Install

```bash
$ pip install phenaki-pytorch
```

@@ -132,6 +142,74 @@ entire_video.shape # (1, 3, 17 + 14 + 14 = 45, 256, 256)

That's it!

## Token Critic

A <a href="https://arxiv.org/abs/2209.04439">new paper</a> suggests that instead of relying on the predicted probabilities of each token as a measure of confidence, one can train an extra critic to decide what to iteratively mask during sampling. You can optionally train this critic for potentially better generations, as shown below.

```python
import torch
from phenaki_pytorch import CViViT, MaskGit, TokenCritic, PhenakiCritic

cvivit = CViViT(
    dim = 512,
    codebook_size = 5000,
    image_size = (256, 128),
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
)

critic = TokenCritic(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6
)

critic_trainer = PhenakiCritic(
    maskgit = maskgit,
    critic = critic,
    cvivit = cvivit
).cuda()

texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
]

videos = torch.randn(3, 3, 3, 256, 128).cuda() # (batch, channels, frames, height, width)

loss = critic_trainer(videos = videos, texts = texts)
loss.backward()
```
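In an actual training run, the critic loss above would of course feed an optimizer step. A minimal sketch, assuming only the critic's parameters are being updated, with an illustrative optimizer and learning rate (not taken from this repo):

```python
from torch.optim import Adam

# illustrative only - the optimizer choice and learning rate are assumptions
opt = Adam(critic.parameters(), lr = 3e-4)

loss = critic_trainer(videos = videos, texts = texts)
loss.backward()
opt.step()
opt.zero_grad()
```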

Then just pass the critic to `Phenaki`

```python
from phenaki_pytorch import Phenaki

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit,
    critic = critic
).cuda()
```

Now your generations should be greatly improved (but who knows, since this is only a month old research).
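Sampling should then work as in the earlier examples. A minimal sketch, assuming the same `sample` interface shown earlier in this README (`cond_scale` being the classifier-free guidance scale):

```python
# assumes phenaki was built with the critic, as above
video = phenaki.sample(
    texts = 'a whale breaching from afar',
    num_frames = 17,
    cond_scale = 5.
) # (1, 3, 17, 256, 128)
```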

## Phenaki Trainer (wip)

This repository will also endeavor to allow the researcher to train on text-to-image and then text-to-video. Similarly, for unconditional training, the researcher should be able to first train on images and then fine-tune on video. Below is an example of text-to-video training.
@@ -209,12 +287,11 @@ trainer = PhenakiTrainer(

```python
trainer.train()
```
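The diff collapses most of that example, so here is a rough sketch of its shape, reusing the `CViViT` / `MaskGit` configuration and the text-video dataset pattern from the token critic example below; the exact arguments are assumptions, not the contents of the hidden hunk:

```python
from phenaki_pytorch import Phenaki, PhenakiTrainer

# cvivit, maskgit and dataset are assumed to be constructed as in the surrounding examples

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()

trainer = PhenakiTrainer(
    phenaki = phenaki,
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = False, # text-to-video, so the dataset should yield (video, caption) tuples
    dataset = dataset
)

trainer.train()
```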

Token critic training is similar

```python
import torch
from torch.utils.data import Dataset
from phenaki_pytorch import CViViT, MaskGit, TokenCritic, PhenakiCritic, PhenakiCriticTrainer

cvivit = CViViT(
    dim = 512,
    codebook_size = 5000,
    image_size = 256,
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
    unconditional = False
)

critic = TokenCritic(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6
)

phenaki_critic = PhenakiCritic(
    maskgit = maskgit,
    critic = critic,
    cvivit = cvivit
).cuda()

# mock text video dataset
# you will have to write your own, returning (<video tensor>, <caption>) tuples

class MockTextVideoDataset(Dataset):
    def __init__(
        self,
        length = 100,
        image_size = 256,
        num_frames = 17
    ):
        super().__init__()
        self.num_frames = num_frames
        self.image_size = image_size
        self.len = length

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        video = torch.randn(3, self.num_frames, self.image_size, self.image_size)
        caption = 'video caption'
        return video, caption

dataset = MockTextVideoDataset()

# pass in the dataset

trainer = PhenakiCriticTrainer(
    phenaki_critic = phenaki_critic,
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = False, # if your dataset above returns (images, caption) pairs, set this to True
    dataset = dataset # pass in your dataset here
)

trainer.train()
```

## Token Critic
Unconditional training is as follows

ex. unconditional images and video training

```python
import torch
from phenaki_pytorch import CViViT, MaskGit, Phenaki, PhenakiTrainer

cvivit = CViViT(
    dim = 512,
    codebook_size = 5000,
    image_size = 256,
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

cvivit.load('/path/to/trained/cvivit.pt')

maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
    unconditional = True # no text conditioning
)

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()

# pass in the folder of images or videos

trainer = PhenakiTrainer(
    phenaki = phenaki,
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = True, # for the sake of example, the path below is a folder of images
    dataset = '/path/to/images/or/video'
)

trainer.train()
```
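After unconditional training, sampling should presumably need no text prompt; a sketch under that assumption, using the same `sample` interface as above:

```python
video = phenaki.sample(num_frames = 17) # (1, 3, 17, 256, 256)
```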

## Todo

@@ -358,7 +440,9 @@ Now your generations should be greatly improved (but who knows, since this is only a month old research)
- [x] wire up accelerate for multi-gpu training for both c-vivit and maskgit
- [x] add depthwise-convs to cvivit for position generating
- [x] some basic video manipulation code, allow for sampled tensor to be saved as gif
- [x] basic critic training code

- [ ] get some basic critic sampling code, show comparison of with and without critic
- [ ] add position generating dsconv to maskgit too
- [ ] add all top of the line research for stabilizing transformers training
- [ ] bring in concatenative token shift (temporal dimension)
2 changes: 1 addition & 1 deletion phenaki_pytorch/__init__.py
@@ -1,4 +1,4 @@
from phenaki_pytorch.phenaki_pytorch import Phenaki, CViViT, MaskGit, MaskGitTrainWrapper, TokenCritic, PhenakiCritic, make_video

from phenaki_pytorch.cvivit_trainer import CViViTTrainer
from phenaki_pytorch.phenaki_trainer import PhenakiTrainer, PhenakiCriticTrainer
11 changes: 5 additions & 6 deletions phenaki_pytorch/phenaki_pytorch.py
@@ -40,15 +40,15 @@ def get_mask_subset_with_prob(mask, prob):
    # randomly selects roughly a `prob` fraction of the True positions in `mask`
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim = -1, keepdim = True)
    mask_excess = (mask.cumsum(dim = -1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    # sample `max_masked` random indices among the unmasked positions, then zero out the excess
    rand = torch.rand((batch, seq_len), device = device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim = -1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    # scatter into a one-wider mask so discarded (0) indices land in a throwaway column
    new_mask = torch.zeros((batch, seq_len + 1), device = device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()

@@ -283,7 +283,6 @@ def forward(
        **kwargs
    ):
        x = rearrange(x, 'b ... -> b (...)')
        b, n, device = *x.shape, x.device

        if not exists(text_mask):