Skip to content

Commit

Permalink
complete the residual vector quantization of the joint embedding musi…
Browse files Browse the repository at this point in the history
…c-text space of mulan, with fetching of learned conditioning embeddings, setup so all three transformers in audiolm can have its own
  • Loading branch information
lucidrains committed Feb 1, 2023
1 parent e81b2d4 commit 5a950d5
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 9 deletions.
23 changes: 22 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,32 @@ embeds = mulan.get_audio_latents(wavs) # during training
embeds = mulan.get_text_latents(texts) # during inference
```

To obtain the conditioning embeddings for the three transformers that are a part of `AudioLM`, you must use the `MuLaNEmbedQuantizer` as so

```python
from musiclm_pytorch import MuLaNEmbedQuantizer

wavs = torch.randn(2, 1024)
embeds = mulan.get_audio_latents(wavs)

# setup the quantizer with the namespaced conditioning embeddings, unique per quantizer as well as namespace (per transformer)

quantizer = MuLaNEmbedQuantizer(
mulan = mulan,
conditioning_dims = (1024, 1024, 1024), # say all three transformers have model dimensions of 1024
namespaces = ('semantic', 'coarse', 'fine')
)

# now say you want the conditioning embeddings for semantic transformer

conds = quantizer(wavs = wavs, namespace = 'semantic') # (2, 8, 1024) - 8 is number of quantizers
```

## Todo

- [x] mulan seems to be using decoupled contrastive learning, offer that as an option
- [x] wrap mulan with mulan wrapper and quantize the output, project to audiolm dimensions

- [ ] wrap mulan with mulan wrapper and quantize the output, project to audiolm dimensions
- [ ] modify audiolm to accept conditioning embeddings, optionally take care of different dimensions through a separate projection
- [ ] audiolm and mulan goes into musiclm and generate, filter with mulan
- [ ] add a version of mulan to <a href="https://github.com/mlfoundations/open_clip">open clip</a>
Expand Down
3 changes: 2 additions & 1 deletion musiclm_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from musiclm_pytorch.musiclm_pytorch import MuLaN, MusicLM
from musiclm_pytorch.musiclm_pytorch import MuLaN, MuLaNEmbedQuantizer, MusicLM

from musiclm_pytorch.musiclm_pytorch import AudioSpectrogramTransformer, TextTransformer
51 changes: 46 additions & 5 deletions musiclm_pytorch/musiclm_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@

from einops import rearrange, repeat, reduce, pack, unpack

from beartype.typing import List, Optional
from beartype.typing import List, Optional, Tuple
from beartype import beartype

# functions

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def round_down_nearest_multiple(n, divisor):
return n // divisor * divisor

Expand Down Expand Up @@ -449,15 +452,26 @@ class MuLaNEmbedQuantizer(nn.Module):
def __init__(
self,
mulan: MuLaN,
conditioning_dims: Tuple[int, ...],
rq_num_quantizers = 8,
rq_ema_decay = 0.9,
codebook_size = 1024,
namespaces: Tuple[str, ...] = ('semantic', 'coarse', 'fine'),

):
super().__init__()
self.mulan = mulan

assert len(namespaces) > 0
self.namespaces = namespaces
self.conditioning_dims = conditioning_dims

assert len(conditioning_dims) == len(namespaces), 'number of conditioning dimensions must be equal to number of namespaces'

dim = mulan.dim_latent

self.rq = ResidualVQ(
dim = mulan.dim_latent,
dim = dim,
num_quantizers = rq_num_quantizers,
codebook_size = codebook_size,
decay = rq_ema_decay,
Expand All @@ -467,12 +481,33 @@ def __init__(
quantize_dropout = False # no quantize dropout
)

self.dim = dim
self.num_codebooks = rq_num_quantizers

self.cond_embeddings = nn.ParameterDict({})

for namespace, conditioning_dim in zip(namespaces, conditioning_dims):
cond_embeddings = nn.Parameter(torch.randn(rq_num_quantizers, codebook_size, conditioning_dim))
nn.init.normal_(cond_embeddings, std = 0.02)

self.cond_embeddings[namespace] = cond_embeddings

self.set_default_namespace(namespaces[0])

def set_default_namespace(self, namespace):
self._default_namespace = namespace

def forward(
self,
wavs = None,
texts = None
texts = None,
namespace = None
):
assert exists(wavs) ^ exist(texts)
assert exists(wavs) ^ exists(texts)

namespace = default(namespace, self._default_namespace)
assert namespace in self.namespaces, f'namespace {namespace} not found'
cond_embeddings = self.cond_embeddings[namespace]

with torch.no_grad():
self.mulan.eval()
Expand All @@ -486,7 +521,13 @@ def forward(

_, indices, _ = self.rq(latents)

return indices
batch, num_codebooks, dim = indices.shape[0], self.num_codebooks, cond_embeddings.shape[-1]

cond_embeddings = repeat(cond_embeddings, 'q c d -> b q c d', b = batch)
indices = repeat(indices, 'b q -> b q 1 d', q = num_codebooks, d = dim)

cond_embeddings = cond_embeddings.gather(2, indices)
return rearrange(cond_embeddings, 'b q 1 d -> b q d')

@beartype
class MusicLM(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'musiclm-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.2',
version = '0.0.3',
license='MIT',
description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
author = 'Phil Wang',
Expand All @@ -22,7 +22,7 @@
'audiolm-pytorch',
'beartype',
'einops>=0.4',
'vector-quantize-pytorch>=0.10.15',
'vector-quantize-pytorch>=1.0.0',
'x-clip',
'torch>=1.6',
'torchaudio'
Expand Down

0 comments on commit 5a950d5

Please sign in to comment.