diff --git a/README.md b/README.md
index 6d173d8..656b7c4 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
-## MusicLM - Pytorch (wip)
+## MusicLM - Pytorch
Implementation of MusicLM, Google's new SOTA model for music generation using attention networks, in Pytorch.
@@ -8,6 +8,12 @@ They are basically using text-conditioned AudioLM
+## Appreciation
+
+- Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research
+
+- 🤗 Huggingface for their accelerate training library
+
## Usage
```bash
$ pip install musiclm-pytorch
```
@@ -134,10 +140,6 @@ music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.T
- [ ] add a version of mulan to open clip
- [ ] set all the proper spectrogram hyperparameters
-## Appreciation
-
-- Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research
-
## Citations
```bibtex
diff --git a/musiclm_pytorch/__init__.py b/musiclm_pytorch/__init__.py
index 3034325..2d88945 100644
--- a/musiclm_pytorch/__init__.py
+++ b/musiclm_pytorch/__init__.py
@@ -1,3 +1,5 @@
from musiclm_pytorch.musiclm_pytorch import MuLaN, MuLaNEmbedQuantizer, MusicLM
from musiclm_pytorch.musiclm_pytorch import AudioSpectrogramTransformer, TextTransformer
+
+from musiclm_pytorch.trainer import MuLaNTrainer
diff --git a/musiclm_pytorch/musiclm_pytorch.py b/musiclm_pytorch/musiclm_pytorch.py
index f7a48b8..42e4143 100644
--- a/musiclm_pytorch/musiclm_pytorch.py
+++ b/musiclm_pytorch/musiclm_pytorch.py
@@ -366,6 +366,10 @@ def __init__(
self.pad_id = pad_id
self.norm = LayerNorm(dim)
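+    # expose the device of the parameters, so raw texts tokenized on cpu can be moved over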
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
def forward(
self,
x = None,
@@ -375,7 +379,7 @@ def forward(
assert exists(x) ^ exists(raw_texts)
if exists(raw_texts):
- x = tokenizer.tokenize(raw_texts)
+ x = tokenizer.tokenize(raw_texts).to(self.device)
if not exists(mask):
mask = x != self.pad_id
@@ -443,7 +447,7 @@ def get_text_latents(
texts = None,
raw_texts: Optional[List[str]] = None
):
- text_embeds = self.text(texts)
+ text_embeds = self.text(texts, raw_texts = raw_texts)
text_latents = self.text_to_latents(text_embeds)
return l2norm(text_latents)
@@ -473,7 +477,7 @@ def forward(
numerator = cosine_sim_exp.diag()
if self.decoupled_contrastive_learning:
- eye = torch.eye(batch, device = device)
+ eye = torch.eye(batch, device = device, dtype = torch.bool)
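+        # masked_fill requires a boolean mask; zeroing the diagonal drops self-similarity terms from the denominator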
cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.)
denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum')
@@ -581,6 +585,10 @@ def __init__(
self.mulan_embed_quantizer = mulan_embed_quantizer
self.audio_lm = audio_lm
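+    # device of the underlying parameters, used to place tokenized text before embedding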
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
@torch.no_grad()
def forward(
self,
@@ -589,7 +597,7 @@ def forward(
):
self.eval()
- texts = tokenizer.tokenize(raw_texts)
+ texts = tokenizer.tokenize(raw_texts).to(self.device)
text_embeds = self.mulan_embed_quantizer(texts = texts)
diff --git a/musiclm_pytorch/trainer.py b/musiclm_pytorch/trainer.py
new file mode 100644
index 0000000..ae5522b
--- /dev/null
+++ b/musiclm_pytorch/trainer.py
@@ -0,0 +1,366 @@
+from pathlib import Path
+from shutil import rmtree
+from functools import wraps
+
+from typing_extensions import Annotated
+
+from beartype import beartype
+from beartype.door import is_bearable
+from beartype.vale import Is
+from beartype.typing import Union, List, Optional, Tuple, Callable
+
+import torch
+from torch import nn
+from torch.optim import AdamW, Adam
+from torch.utils.data import Dataset, DataLoader, random_split
+from torch.nn.utils.rnn import pad_sequence
+
+from musiclm_pytorch import MuLaN
+
+from accelerate import Accelerator, DistributedType
+
+# for automatically routing data emitted from a dataset to keywords of the transformer wrappers
+
+DATASET_FIELD_TYPE_CONFIG = dict(
+ wavs = Annotated[
+ torch.Tensor,
+ Is[lambda t: t.dtype == torch.float and t.ndim in {2, 3}]
+ ],
+ raw_texts = List[str],
+ texts = Annotated[
+ torch.Tensor,
+ Is[lambda t: t.dtype == torch.long and t.ndim == 2]
+ ],
+)
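+
+# e.g. a collated batch of (float wav tensor, list of caption strings) is routed to
+# mulan(wavs = ..., raw_texts = ...) by matching each element against the types above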
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+def default(*args):
+ for arg in args:
+ if exists(arg):
+ return arg
+ return None
+
+def noop(*args, **kwargs):
+ pass
+
+def cycle(dl):
+ while True:
+ for data in dl:
+ yield data
+
+def cast_tuple(t):
+ return t if isinstance(t, (tuple, list)) else (t,)
+
+def yes_or_no(question):
+ answer = input(f'{question} (y/n) ')
+ return answer.lower() in ('yes', 'y')
+
+def accum_log(log, new_logs):
+ for key, new_value in new_logs.items():
+ old_value = log.get(key, 0.)
+ log[key] = old_value + new_value
+ return log
+
+# auto data to module keyword argument routing functions
+
+def has_duplicates(tup):
+ counts = dict()
+ for el in tup:
+ if el not in counts:
+ counts[el] = 0
+ counts[el] += 1
+ return any(filter(lambda count: count > 1, counts.values()))
+
+def determine_types(data, config):
+ output = []
+ for el in data:
+ for name, data_type in config.items():
+ if is_bearable(el, data_type):
+ output.append(name)
+ break
+ else:
+            raise TypeError(f'unable to determine type of {el}')
+
+ return tuple(output)
+
+# optimizer functions
+
+def separate_weight_decayable_params(params):
+ wd_params, no_wd_params = [], []
+ for param in params:
+ param_list = no_wd_params if param.ndim < 2 else wd_params
+ param_list.append(param)
+ return wd_params, no_wd_params
+
+def get_optimizer(
+ params,
+ lr = 1e-4,
+ wd = 1e-2,
+ betas = (0.9, 0.99),
+ eps = 1e-8,
+ filter_by_requires_grad = False,
+ group_wd_params = True,
+ **kwargs
+):
+ if filter_by_requires_grad:
+ params = list(filter(lambda t: t.requires_grad, params))
+
+ if wd == 0:
+ return Adam(params, lr = lr, betas = betas, eps = eps)
+
+ if group_wd_params:
+ wd_params, no_wd_params = separate_weight_decayable_params(params)
+
+ params = [
+ {'params': wd_params},
+ {'params': no_wd_params, 'weight_decay': 0},
+ ]
+
+ return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
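+
+# e.g. get_optimizer(mulan.parameters(), lr = 3e-4, wd = 1e-2) yields AdamW with
+# biases and norm weights (ndim < 2) exempted from weight decay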
+
+# dataloader functions
+
+def collate_one_or_multiple_tensors(fn):
+ @wraps(fn)
+ def inner(data):
+ is_one_data = not isinstance(data[0], tuple)
+
+ if is_one_data:
+ data = torch.stack(data)
+ return (data,)
+
+ outputs = []
+ for datum in zip(*data):
+ if is_bearable(datum, Tuple[str, ...]):
+ output = list(datum)
+ else:
+ output = fn(datum)
+
+ outputs.append(output)
+
+ return tuple(outputs)
+
+ return inner
+
+@collate_one_or_multiple_tensors
+def curtail_to_shortest_collate(data):
+    min_len = min(datum.shape[0] for datum in data)
+ data = [datum[:min_len] for datum in data]
+ return torch.stack(data)
+
+@collate_one_or_multiple_tensors
+def pad_to_longest_fn(data):
+ return pad_sequence(data, batch_first = True)
+
+def get_dataloader(ds, pad_to_longest = True, **kwargs):
+ collate_fn = pad_to_longest_fn if pad_to_longest else curtail_to_shortest_collate
+ return DataLoader(ds, collate_fn = collate_fn, **kwargs)
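+
+# e.g. get_dataloader(ds, batch_size = 4) right-pads variable length audio to the
+# longest clip in the batch; tuples of raw text strings pass through as plain lists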
+
+# mulan trainer
+
+@beartype
+class MuLaNTrainer(nn.Module):
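+    """
+    trainer for the contrastive pretraining of MuLaN with huggingface accelerate
+
+    minimal usage sketch, assuming a dataset that yields (wav, raw_text) pairs:
+
+        trainer = MuLaNTrainer(mulan = mulan, dataset = audio_text_dataset, batch_size = 4)
+        trainer.train()
+    """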
+ def __init__(
+ self,
+ mulan: MuLaN,
+ dataset: Dataset,
+ *,
+ num_train_steps = None,
+ batch_size,
+ data_max_length = None,
+ lr = 3e-4,
+ grad_accum_every = 1,
+ wd = 0.,
+ max_grad_norm = 0.5,
+ valid_frac = 0.05,
+ random_split_seed = 42,
+ save_model_every = 1000,
+ results_folder = './results',
+ accelerate_kwargs: dict = dict()
+ ):
+ super().__init__()
+ self.accelerator = Accelerator(**accelerate_kwargs)
+
+ self.mulan = mulan
+
+ self.register_buffer('steps', torch.Tensor([0]))
+
+        self.num_train_steps = default(num_train_steps, len(dataset)) # defaults to one step per sample (one epoch at batch size 1)
+ self.batch_size = batch_size
+ self.grad_accum_every = grad_accum_every
+
+ # optimizers
+
+ self.optim = get_optimizer(mulan.parameters(), lr = lr, wd = wd)
+
+ # max grad norm
+
+ self.max_grad_norm = max_grad_norm
+
+ # create dataset
+
+ self.ds = dataset
+ self.ds_fields = None
+
+ # split for validation
+
+ if valid_frac > 0:
+ train_size = int((1 - valid_frac) * len(self.ds))
+ valid_size = len(self.ds) - train_size
+ self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
+            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
+ else:
+ self.valid_ds = self.ds
+ self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')
+
+ # dataloader
+
+ self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True)
+
+ self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True)
+
+ # prepare with accelerator
+
+ (
+ self.mulan,
+ self.optim,
+ self.dl,
+ self.valid_dl
+ ) = self.accelerator.prepare(
+ self.mulan,
+ self.optim,
+ self.dl,
+ self.valid_dl
+ )
+
+ # dataloader iterators
+
+ self.dl_iter = cycle(self.dl)
+ self.valid_dl_iter = cycle(self.valid_dl)
+
+ self.save_model_every = save_model_every
+
+ hps = dict(
+            num_train_steps = self.num_train_steps,
+ data_max_length = data_max_length,
+ learning_rate = lr
+ )
+
+ self.accelerator.init_trackers("mulan", config = hps)
+
+ # results folder
+
+ self.results_folder = Path(results_folder)
+
+ if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
+ rmtree(str(self.results_folder))
+
+ self.results_folder.mkdir(parents = True, exist_ok = True)
+
+ # to device
+
+ self.mulan.to(self.device)
+
+ def save(self, path):
+ pkg = dict(
+ model = self.accelerator.get_state_dict(self.mulan),
+ optim = self.optim.state_dict()
+ )
+ torch.save(pkg, path)
+
+ def load(self, path):
+ path = Path(path)
+ assert path.exists()
+ pkg = torch.load(str(path))
+
+ mulan = self.accelerator.unwrap_model(self.mulan)
+ mulan.load_state_dict(pkg['model'])
+ self.optim.load_state_dict(pkg['optim'])
+
+ def print(self, msg):
+ self.accelerator.print(msg)
+
+ @property
+ def device(self):
+ return self.accelerator.device
+
+ @property
+ def is_distributed(self):
+ return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
+
+ @property
+ def is_main(self):
+ return self.accelerator.is_main_process
+
+ @property
+ def is_local_main(self):
+ return self.accelerator.is_local_main_process
+
+ def data_tuple_to_kwargs(self, data):
+ if not exists(self.ds_fields):
+ self.ds_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
+ assert not has_duplicates(self.ds_fields), 'dataset fields must not have duplicate field names'
+
+ return dict(zip(self.ds_fields, data))
+
+ def train_step(self):
+ device = self.device
+
+ steps = int(self.steps.item())
+
+ self.mulan.train()
+
+ # logs
+
+ logs = {}
+
+        # update mulan, accumulating gradients over grad_accum_every batches
+
+ for _ in range(self.grad_accum_every):
+ data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))
+ loss = self.mulan(**data_kwargs)
+
+ self.accelerator.backward(loss / self.grad_accum_every)
+
+ accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
+
+ if exists(self.max_grad_norm):
+ self.accelerator.clip_grad_norm_(self.mulan.parameters(), self.max_grad_norm)
+
+ self.optim.step()
+ self.optim.zero_grad()
+
+ # log
+
+ self.print(f"{steps}: loss: {logs['loss']}")
+ self.accelerator.log({"train_loss": logs['loss']}, step = steps)
+
+ # save model every so often
+
+ if self.is_main and not (steps % self.save_model_every):
+            state_dict = self.accelerator.get_state_dict(self.mulan)
+ model_path = str(self.results_folder / f'mulan.{steps}.pt')
+ torch.save(state_dict, model_path)
+
+ self.print(f'{steps}: saving model to {str(self.results_folder)}')
+
+ self.steps += 1
+ return logs
+
+ def train(self, log_fn: Callable = noop):
+
+ while self.steps < self.num_train_steps:
+ logs = self.train_step()
+ log_fn(logs)
+
+ self.print('training complete')
diff --git a/setup.py b/setup.py
index d4deb72..bce50da 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
setup(
name = 'musiclm-pytorch',
packages = find_packages(exclude=[]),
- version = '0.0.10',
+ version = '0.0.11',
license='MIT',
description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
author = 'Phil Wang',
@@ -19,9 +19,10 @@
'contrastive learning'
],
install_requires=[
- 'audiolm-pytorch>=0.9.3',
+ 'accelerate',
+ 'audiolm-pytorch>=0.10.4',
'beartype',
- 'einops>=0.4',
+ 'einops>=0.6',
'vector-quantize-pytorch>=1.0.0',
'x-clip',
'torch>=1.6',