From 90021a1a53df195cdae509d6ccffbcec102ef022 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 6 Sep 2023 08:12:45 -0700
Subject: [PATCH] further cleanup the distributed code

---
 musiclm_pytorch/distributed.py     | 54 +++++++++++++++++-------------
 musiclm_pytorch/musiclm_pytorch.py | 17 +++++-----
 setup.py                           |  2 +-
 3 files changed, 39 insertions(+), 34 deletions(-)

diff --git a/musiclm_pytorch/distributed.py b/musiclm_pytorch/distributed.py
index 756107b..fa5d640 100644
--- a/musiclm_pytorch/distributed.py
+++ b/musiclm_pytorch/distributed.py
@@ -1,4 +1,5 @@
 import torch
+from torch import nn
 from torch.autograd import Function
 import torch.distributed as distributed
 
@@ -33,37 +34,42 @@ def all_gather_variable_dim(t, dim = 0, sizes = None):
 
     return gathered_tensor, sizes
 
-class AllGather(Function):
+class AllGatherFunction(Function):
     @staticmethod
-    def forward(ctx, x, dim, sizes):
-        assert distributed.is_initialized() and distributed.get_world_size() > 1
+    def forward(ctx, x, dim, sizes, all_reduce_grads):
         x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
-        ctx.batch_sizes = batch_sizes.tolist()
         ctx.dim = dim
-        return x, batch_sizes
-
-    @staticmethod
-    def backward(ctx, grads, _):
-        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
-        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
-        return grads_by_rank[rank], None, None
-
-all_gather = AllGather.apply
-
-class AllGatherAllReduceGrads(Function):
-    @staticmethod
-    def forward(ctx, x, dim, sizes):
-        assert distributed.is_initialized() and distributed.get_world_size() > 1
-        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
+        ctx.all_reduce_grads = all_reduce_grads
         ctx.batch_sizes = batch_sizes.tolist()
-        ctx.dim = dim
         return x, batch_sizes
 
     @staticmethod
     def backward(ctx, grads, _):
-        distributed.all_reduce(grads)
         batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
-        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
-        return grads_by_rank[rank], None, None
+        if ctx.all_reduce_grads:
+            distributed.all_reduce(grads)
 
-all_gather_all_reduce_grads = AllGatherAllReduceGrads.apply
+        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
+        return grads_by_rank[rank], None, None, None
+
+class AllGather(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        all_reduce_grads = False
+    ):
+        super().__init__()
+        self.dim = dim
+        self.all_reduce_grads = all_reduce_grads
+        self.is_distributed = distributed.is_initialized() and distributed.get_world_size() > 1
+
+    def forward(
+        self,
+        x,
+        sizes = None
+    ):
+        if not self.is_distributed:
+            return x, None
+
+        return AllGatherFunction.apply(x, self.dim, sizes, self.all_reduce_grads)
diff --git a/musiclm_pytorch/musiclm_pytorch.py b/musiclm_pytorch/musiclm_pytorch.py
index ee3352d..77f90e9 100644
--- a/musiclm_pytorch/musiclm_pytorch.py
+++ b/musiclm_pytorch/musiclm_pytorch.py
@@ -11,7 +11,7 @@ from audiolm_pytorch.utils import AudioConditionerBase
 
 import torch.distributed as dist
 
-from musiclm_pytorch.distributed import all_gather, all_gather_all_reduce_grads
+from musiclm_pytorch.distributed import AllGather
 
 from x_clip.tokenizer import tokenizer
 from vector_quantize_pytorch import ResidualVQ
@@ -266,7 +266,7 @@ def __init__(
         self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
         self.decoupled_contrastive_learning = decoupled_contrastive_learning
 
-        self.needs_all_gather = dist.is_initialized() and dist.get_world_size() > 1
+        self.all_gather = AllGather(dim = 2)
 
     @property
     def device(self):
@@ -281,9 +281,9 @@
     def forward(self, audio_latents, text_latents):
         batch = audio_latents.shape[1]
 
-        if self.needs_all_gather:
+        if self.all_gather.is_distributed:
             latents = torch.stack((audio_latents, text_latents))
-            latents, _ = all_gather(latents, 2, None)
+            latents, _ = self.all_gather(latents)
             audio_latents, text_latents = latents
 
         sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)
@@ -320,7 +320,7 @@ def __init__(
         self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
         self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias)
 
-        self.needs_all_gather = dist.is_initialized() and dist.get_world_size() > 1
+        self.all_gather = AllGather(dim = 1, all_reduce_grads = True)
 
     @property
     def device(self):
@@ -335,8 +335,7 @@ def forward(self, audio_latents, text_latents):
         if text_latents.ndim == 2:
             text_latents = rearrange(text_latents, '... -> 1 ...')
 
-        if self.needs_all_gather:
-            text_latents, batch_sizes = all_gather_all_reduce_grads(text_latents, 1, None)
+        text_latents, rank_sizes = self.all_gather(text_latents)
 
         n = text_latents.shape[1]
 
@@ -346,8 +345,8 @@ def forward(self, audio_latents, text_latents):
         labels = torch.eye(n, device = device)
 
-        if self.needs_all_gather:
-            labels_by_ranks = labels.split(batch_sizes.tolist(), dim = 0)
+        if exists(rank_sizes):
+            labels_by_ranks = labels.split(rank_sizes.tolist(), dim = 0)
             labels = labels_by_ranks[dist.get_rank()]
 
         labels = 2 * rearrange(labels, 'i j -> 1 i j') - torch.ones_like(sims)
 
diff --git a/setup.py b/setup.py
index 7497b98..989a95c 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'musiclm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.5',
+  version = '0.2.6',
   license='MIT',
   description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
   author = 'Phil Wang',
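A minimal usage sketch of the reworked API (an illustration on top of the patch, not part of it; tensor shapes are hypothetical): AllGather is now an nn.Module constructed once with the gather dimension, and outside of a distributed run it passes the input through and returns None for the per-rank sizes.

    import torch
    from musiclm_pytorch.distributed import AllGather

    all_gather = AllGather(dim = 1)           # gather along the batch dimension

    text_latents = torch.randn(4, 8, 512)     # hypothetical (layers, batch, dim)
    text_latents, rank_sizes = all_gather(text_latents)

    # without torch.distributed initialized, rank_sizes is None and the input
    # comes back unchanged; in a distributed run, rank_sizes holds each rank's
    # batch size, which the sigmoid contrastive loss above uses to split its
    # label matrix by rank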