diff --git a/musiclm_pytorch/distributed.py b/musiclm_pytorch/distributed.py
index 32cca72..756107b 100644
--- a/musiclm_pytorch/distributed.py
+++ b/musiclm_pytorch/distributed.py
@@ -49,3 +49,21 @@ def backward(ctx, grads, _):
         return grads_by_rank[rank], None, None
 
 all_gather = AllGather.apply
+
+class AllGatherAllReduceGrads(Function):
+    @staticmethod
+    def forward(ctx, x, dim, sizes):
+        assert distributed.is_initialized() and distributed.get_world_size() > 1
+        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
+        ctx.batch_sizes = batch_sizes.tolist()
+        ctx.dim = dim
+        return x, batch_sizes
+
+    @staticmethod
+    def backward(ctx, grads, _):
+        distributed.all_reduce(grads)
+        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
+        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
+        return grads_by_rank[rank], None, None
+
+all_gather_all_reduce_grads = AllGatherAllReduceGrads.apply
diff --git a/musiclm_pytorch/musiclm_pytorch.py b/musiclm_pytorch/musiclm_pytorch.py
index 818ebff..bea3616 100644
--- a/musiclm_pytorch/musiclm_pytorch.py
+++ b/musiclm_pytorch/musiclm_pytorch.py
@@ -11,7 +11,7 @@ from audiolm_pytorch.utils import AudioConditionerBase
 
 import torch.distributed as dist
-from musiclm_pytorch.distributed import all_gather
+from musiclm_pytorch.distributed import all_gather, all_gather_all_reduce_grads
 
 from x_clip.tokenizer import tokenizer
 from vector_quantize_pytorch import ResidualVQ
 
@@ -320,23 +320,37 @@ def __init__(
         self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
         self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias)
 
+        self.needs_all_gather = dist.is_initialized() and dist.get_world_size() > 1
+
     @property
     def device(self):
         return next(self.parameters()).device
 
     def forward(self, audio_latents, text_latents):
+        device = self.device
+
         if audio_latents.ndim == 2:
             audio_latents = rearrange(audio_latents, '... -> 1 ...')
 
         if text_latents.ndim == 2:
             text_latents = rearrange(text_latents, '... -> 1 ...')
 
-        n = audio_latents.shape[1]
+        if self.needs_all_gather:
+            text_latents, batch_sizes = all_gather_all_reduce_grads(text_latents, 1, None)
+
+        n = text_latents.shape[1]
 
         sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)
 
         sims = sims * self.temperatures.exp() + self.bias
-        labels = 2 * rearrange(torch.eye(n), 'i j -> 1 i j') - torch.ones_like(sims)
+
+        labels = torch.eye(n, device = device)
+
+        if self.needs_all_gather:
+            labels_by_ranks = labels.split(batch_sizes.tolist(), dim = 0)
+            labels = labels_by_ranks[dist.get_rank()]
+
+        labels = 2 * rearrange(labels, 'i j -> 1 i j') - torch.ones_like(sims)
 
         return -F.logsigmoid(labels * sims).sum() / n
 
diff --git a/setup.py b/setup.py
index 9050232..8322db0 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'musiclm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.3',
+  version = '0.2.4',
   license='MIT',
   description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
   author = 'Phil Wang',
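
Note (not part of the patch): the new `AllGatherAllReduceGrads` differs from the existing `AllGather` only in its backward pass, where incoming gradients are first summed across ranks with `distributed.all_reduce` before each rank extracts the slice belonging to its own batch. This matters for the SigLIP-style loss above: after gathering, every rank's loss term depends on every other rank's text latents, so the gradient with respect to the gathered tensor must be accumulated over all ranks before being split back per rank. Below is a minimal smoke-test sketch for the new op; the file name `test_all_gather_grads.py`, the gloo backend, and the tensor shapes are illustrative assumptions, not part of the repository.

```python
# test_all_gather_grads.py -- hypothetical smoke test, not included in the repo
# launch with: torchrun --nproc_per_node=2 test_all_gather_grads.py

import torch
import torch.distributed as dist

from musiclm_pytorch.distributed import all_gather_all_reduce_grads

def main():
    # torchrun sets MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE for us
    dist.init_process_group(backend = 'gloo')
    rank = dist.get_rank()

    # per-rank text latents shaped (layers, batch, dim), as in SigmoidContrastiveLearning
    text_latents = torch.randn(1, 4, 8, requires_grad = True)

    # forward: gather along the batch dimension (dim = 1); passing sizes = None
    # lets the op discover each rank's batch size on its own
    gathered, batch_sizes = all_gather_all_reduce_grads(text_latents, 1, None)
    assert gathered.shape[1] == int(batch_sizes.sum())

    # backward: grads are all-reduced across ranks, then each rank
    # receives only the slice for its own local batch
    gathered.sum().backward()
    assert text_latents.grad.shape == text_latents.shape

    if rank == 0:
        print('gathered:', tuple(gathered.shape), 'batch sizes:', batch_sizes.tolist())

    dist.destroy_process_group()

if __name__ == '__main__':
    main()
```

The matching change in `SigmoidContrastiveLearning.forward` keeps the labels consistent with the larger effective batch: `n` becomes the global batch size, and the identity label matrix is built at global size, then split by `batch_sizes` so each rank keeps only the rows corresponding to its local audio latents.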