
Commit

get a mulan trainer out, remove wip
lucidrains committed Feb 6, 2023
1 parent 7dcfc97 commit c138c93
Showing 5 changed files with 391 additions and 12 deletions.
README.md (12 changes: 7 additions & 5 deletions)
@@ -1,13 +1,19 @@
<img src="./musiclm.png" width="450px"></img>

## MusicLM - Pytorch (wip)
## MusicLM - Pytorch

Implementation of <a href="https://google-research.github.io/seanet/musiclm/examples/">MusicLM</a>, Google's new SOTA model for music generation using attention networks, in Pytorch.

They are basically using text-conditioned <a href="https://github.com/lucidrains/audiolm-pytorch">AudioLM</a>, but surprisingly with the embeddings from a contrastively trained text-audio model named <a href="https://arxiv.org/abs/2208.12415">MuLan</a>. MuLan is what will be built out in this repository, with AudioLM modified from the other repository to support the music generation needs here.
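
In code terms, MuLan is a pair of towers (an audio spectrogram transformer and a text transformer) trained contrastively on matching (audio, text) pairs. A hedged sketch of how this repository wires the two towers together; the hyperparameters below are illustrative, not prescribed:

```python
import torch
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

# one contrastive step on a toy batch of (waveform, text-token) pairs
wavs = torch.randn(2, 1024)
texts = torch.randint(0, 20000, (2, 256))

loss = mulan(wavs, texts)
loss.backward()
```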

Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication alongside the <a href="https://laion.ai/">LAION</a> community.

## Appreciation

- <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work on and open source cutting-edge artificial intelligence research

- <a href="https://huggingface.co/">🤗 Huggingface</a> for their <a href="https://huggingface.co/docs/accelerate/index">accelerate</a> training library

## Usage

```install
@@ -134,10 +140,6 @@ music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.T
- [ ] add a version of mulan to <a href="https://github.com/mlfoundations/open_clip">open clip</a>
- [ ] set all the proper spectrogram hyperparameters

## Appreciation

- <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work and open source cutting edge artificial intelligence research

## Citations

```bibtex
musiclm_pytorch/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -1,3 +1,5 @@
from musiclm_pytorch.musiclm_pytorch import MuLaN, MuLaNEmbedQuantizer, MusicLM

from musiclm_pytorch.musiclm_pytorch import AudioSpectrogramTransformer, TextTransformer

from musiclm_pytorch.trainer import MuLaNTrainer
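
The newly exported MuLaNTrainer is the headline of this commit ("get a mulan trainer out"), but the trainer file itself did not load in this view. The sketch below is therefore only a guess at its usage: the constructor keywords (mulan, dataset, batch_size, num_train_steps), the mock dataset, and the .train() call are assumptions, not confirmed by this diff.

```python
import torch
from torch.utils.data import Dataset

from musiclm_pytorch import MuLaN, MuLaNTrainer, AudioSpectrogramTransformer, TextTransformer

# hypothetical paired (waveform, text-token) dataset, for illustration only
class MockTextAudioDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(1024), torch.randint(0, 20000, (256,))

mulan = MuLaN(
    audio_transformer = AudioSpectrogramTransformer(dim = 512, depth = 6, heads = 8, dim_head = 64),
    text_transformer = TextTransformer(dim = 512, depth = 6, heads = 8, dim_head = 64)
)

# keyword names below are assumed, not taken from this commit
trainer = MuLaNTrainer(
    mulan = mulan,
    dataset = MockTextAudioDataset(),
    batch_size = 4,
    num_train_steps = 1000
)

trainer.train()
```
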
musiclm_pytorch/musiclm_pytorch.py (16 changes: 12 additions & 4 deletions)
@@ -366,6 +366,10 @@ def __init__(
        self.pad_id = pad_id
        self.norm = LayerNorm(dim)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        x = None,
@@ -375,7 +379,7 @@ def forward(
        assert exists(x) ^ exists(raw_texts)

        if exists(raw_texts):
            x = tokenizer.tokenize(raw_texts)
            x = tokenizer.tokenize(raw_texts).to(self.device)

        if not exists(mask):
            mask = x != self.pad_id
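
For context on the hunk above: tokenizer.tokenize returns CPU tensors, so the new device property lets the text transformer move the token ids onto whatever device its weights live on before embedding them. A minimal, self-contained sketch of the pattern; the Dummy module below is hypothetical, not from this repository:

```python
import torch
from torch import nn

class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @property
    def device(self):
        # device of the first parameter, i.e. wherever the module's weights live
        return next(self.parameters()).device

model = Dummy()                          # move to GPU with model.cuda() if one is available
tokens = torch.randint(0, 10, (2, 8))    # tokenizers return CPU tensors
tokens = tokens.to(model.device)         # avoids a CPU/GPU mismatch error downstream
```
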
@@ -443,7 +447,7 @@ def get_text_latents(
        texts = None,
        raw_texts: Optional[List[str]] = None
    ):
        text_embeds = self.text(texts)
        text_embeds = self.text(texts, raw_texts = raw_texts)
        text_latents = self.text_to_latents(text_embeds)
        return l2norm(text_latents)

@@ -473,7 +477,7 @@ def forward(
        numerator = cosine_sim_exp.diag()

        if self.decoupled_contrastive_learning:
            eye = torch.eye(batch, device = device)
            eye = torch.eye(batch, device = device, dtype = torch.bool)
            cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.)

        denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum')
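
For context: torch.eye returns a float tensor by default and masked_fill only accepts boolean masks, so the added dtype = torch.bool avoids a runtime error when zeroing out the diagonal (the positive pairs) in the decoupled contrastive learning branch. A minimal illustration; the shapes are assumed for the sketch:

```python
import torch

batch = 4
cosine_sim_exp = torch.randn(batch, batch).exp()   # stand-in for the exponentiated similarity matrix

eye = torch.eye(batch, dtype = torch.bool)          # boolean identity mask; a float mask would raise an error

# decoupled contrastive learning: drop the positive (diagonal) terms from the denominator
denominator = cosine_sim_exp.masked_fill(eye, 0.).sum(dim = -1)
```
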
@@ -581,6 +585,10 @@ def __init__(
        self.mulan_embed_quantizer = mulan_embed_quantizer
        self.audio_lm = audio_lm

    @property
    def device(self):
        return next(self.parameters()).device

    @torch.no_grad()
    def forward(
        self,
@@ -589,7 +597,7 @@
    ):
        self.eval()

        texts = tokenizer.tokenize(raw_texts)
        texts = tokenizer.tokenize(raw_texts).to(self.device)

        text_embeds = self.mulan_embed_quantizer(texts = texts)


