Commit: add patch dropout

lucidrains committed Feb 5, 2023
1 parent 5aaee64 commit 7dcfc97
Showing 3 changed files with 38 additions and 2 deletions.
README.md (10 additions, 0 deletions)
@@ -170,4 +170,14 @@ music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.T
}
```

```bibtex
@article{Liu2022PatchDropoutEV,
title = {PatchDropout: Economizing Vision Transformers Using Patch Dropout},
author = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.07220}
}
```

*The only truth is music.* - Jack Kerouac
musiclm_pytorch/musiclm_pytorch.py (27 additions, 1 deletion)
@@ -193,6 +193,27 @@ def forward(self, x, mask = None):

return x

# Patch Dropout - https://arxiv.org/abs/2208.07220

class PatchDropout(nn.Module):
def __init__(self, prob):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob

def forward(self, x, force_keep_all = False):
if not self.training or self.prob == 0. or force_keep_all:
return x

b, n, _, device = *x.shape, x.device

batch_indices = torch.arange(b, device = device)
batch_indices = rearrange(batch_indices, '... -> ... 1')
num_patches_keep = max(1, int(n * (1 - self.prob)))
patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices

return x[batch_indices, patch_indices_keep]
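
# (illustrative aside, not part of the original commit) the randn + topk above draws an
# independent random subset of num_patches_keep patch indices, without replacement,
# for every batch element, e.g.
#
#   torch.randn(2, 6).topk(3, dim = -1).indices   # -> shape (2, 3), distinct indices per row
#
# indexing x with the (b, 1) batch_indices and the (b, keep) patch_indices_keep then
# gathers those patches, so the sequence shrinks from n to num_patches_keep during training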

# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778

def pair(t):
@@ -219,7 +240,8 @@ def __init__(
spec_aug_stretch_factor = 0.8,
spec_aug_freq_mask = 80,
spec_aug_time_mask = 80,
dual_patchnorm = True
dual_patchnorm = True,
patch_dropout_prob = 0.5
):
super().__init__()
self.dim = dim
@@ -264,6 +286,8 @@ def __init__(

self.norm = LayerNorm(dim)

self.patch_dropout = PatchDropout(patch_dropout_prob)

def forward(self, x):
x = self.spec(x)

@@ -294,6 +318,8 @@ def forward(self, x):

x = rearrange(x, 'b ... c -> b (...) c')

x = self.patch_dropout(x)

x = self.transformer(x)

# final global average and norm (most recent papers show this is superior to CLS token)
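For context, a minimal sketch (not part of the commit) of how the new module behaves; it assumes `PatchDropout` can be imported directly from `musiclm_pytorch/musiclm_pytorch.py`, where this diff defines it:

```python
import torch
from musiclm_pytorch.musiclm_pytorch import PatchDropout  # class added in this commit

patch_dropout = PatchDropout(prob = 0.5)

x = torch.randn(2, 100, 512)                           # (batch, num patches, dim)

patch_dropout.train()
print(patch_dropout(x).shape)                          # torch.Size([2, 50, 512]) - a random half of the patches is kept
print(patch_dropout(x, force_keep_all = True).shape)   # torch.Size([2, 100, 512]) - dropout explicitly bypassed

patch_dropout.eval()
print(patch_dropout(x).shape)                          # torch.Size([2, 100, 512]) - identity at inference
```

Inside `AudioSpectrogramTransformer.forward`, the same call sits between patchification and `self.transformer`, so with the new default `patch_dropout_prob = 0.5` roughly half of the patch tokens reach attention during training, which is the compute saving and regularization the cited PatchDropout paper targets.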
setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'musiclm-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.9',
version = '0.0.10',
license='MIT',
description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
author = 'Phil Wang',