diff --git a/README.md b/README.md
index 7be6833..625aad3 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,16 @@ Please join AI Coffeebreak explanation
+## Appreciation
+
+- Stability.ai for the generous sponsorship to work on cutting-edge artificial intelligence research
+
+- 🤗 Huggingface for their amazing transformers and accelerate libraries
+
+- Guillem for his ongoing contributions
+
+- You? If you are a great machine learning engineer and/or researcher, feel free to contribute to the frontier of open source generative AI
+
## Install
```bash
@@ -132,6 +142,74 @@ entire_video.shape # (1, 3, 17 + 14 + 14 = 45, 256, 256)
That's it!
+## Token Critic
+
+A new paper suggests that instead of relying on the predicted probabilities of each token as a measure of confidence, one can train an extra critic to decide what to iteratively mask during sampling. You can optionally train this critic for potentially better generations, as shown below.
+
+```python
+import torch
+from phenaki_pytorch import CViViT, MaskGit, TokenCritic, PhenakiCritic
+
+cvivit = CViViT(
+ dim = 512,
+ codebook_size = 5000,
+ image_size = (256, 128),
+ patch_size = 32,
+ temporal_patch_size = 2,
+ spatial_depth = 4,
+ temporal_depth = 4,
+ dim_head = 64,
+ heads = 8
+)
+
+maskgit = MaskGit(
+ num_tokens = 5000,
+ max_seq_len = 1024,
+ dim = 512,
+ dim_context = 768,
+ depth = 6,
+)
+
+critic = TokenCritic(
+ num_tokens = 5000,
+ max_seq_len = 1024,
+ dim = 512,
+ dim_context = 768,
+ depth = 6
+)
+
+critic_trainer = PhenakiCritic(
+ maskgit = maskgit,
+ critic = critic,
+ cvivit = cvivit
+).cuda()
+
+texts = [
+ 'a whale breaching from afar',
+ 'young girl blowing out candles on her birthday cake',
+ 'fireworks with blue and green sparkles'
+]
+
+videos = torch.randn(3, 3, 3, 256, 128).cuda() # (batch, channels, frames, height, width)
+
+loss = critic_trainer(videos = videos, texts = texts)
+loss.backward()
+```
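+
+To actually optimize the critic, you can wrap the loss computation above in a standard training loop. Below is a minimal sketch, assuming a plain Adam optimizer from `torch.optim` (not part of this repository's API); the learning rate and number of steps are illustrative.
+
+```python
+from torch.optim import Adam
+
+# optimize the parameters held by the critic trainer (maskgit + critic)
+opt = Adam(critic_trainer.parameters(), lr = 3e-4)
+
+for _ in range(100): # number of steps is illustrative
+    loss = critic_trainer(videos = videos, texts = texts)
+    loss.backward()
+    opt.step()
+    opt.zero_grad()
+```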
+
+Then just pass the critic to `Phenaki`
+
+```python
+phenaki = Phenaki(
+    cvivit = cvivit,
+    maskgit = maskgit,
+    critic = critic
+).cuda()
+```
+
+Now your generations should be greatly improved (but who knows, since this research is only a month old).
+
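+You can then sample from the critic-equipped `Phenaki` exactly as before. A minimal sketch, assuming the `.sample(texts = ..., num_frames = ..., cond_scale = ...)` interface from the earlier usage example in this README; the prompt and values are illustrative.
+
+```python
+# sample a short video conditioned on text, with the token critic guiding the iterative unmasking
+video = phenaki.sample(
+    texts = 'a squirrel examines an acorn',
+    num_frames = 17,
+    cond_scale = 5. # classifier free guidance scale for text conditioning
+)
+
+video.shape # (1, 3, 17, 256, 128) - (batch, channels, frames, height, width), given the (256, 128) image size above
+```
+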
## Phenaki Trainer (wip)
This repository will also endeavor to allow the researcher to train on text-to-image and then text-to-video. Similarly, for unconditional training, the researcher should be able to first train on images and then fine tune on video. Below is an example for text-to-video
@@ -209,12 +287,11 @@ trainer = PhenakiTrainer(
trainer.train()
```
-Unconditional is as follows
-
-ex. unconditional images and video training
+Token critic training works similarly
```python
import torch
+from torch.utils.data import Dataset
-from phenaki_pytorch import CViViT, MaskGit, Phenaki, PhenakiTrainer
+from phenaki_pytorch import CViViT, MaskGit, TokenCritic, Phenaki, PhenakiCritic, PhenakiTrainer
cvivit = CViViT(
@@ -240,36 +317,70 @@ maskgit = MaskGit(
unconditional = False
)
-phenaki = Phenaki(
- cvivit = cvivit,
- maskgit = maskgit
+critic = TokenCritic(
+ num_tokens = 5000,
+ max_seq_len = 1024,
+ dim = 512,
+ dim_context = 768,
+ depth = 6
+)
+
+phenaki_critic = PhenakiCritic(
+ maskgit = maskgit,
+ critic = critic,
+ cvivit = cvivit
).cuda()
-# pass in the folder to images or video
+# mock text video dataset
+# you will have to extend your own, and return the (