From 4a78b41feab25d0ad0f4a1be8ea46b0aaf83d23a Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 6 Sep 2023 09:36:47 -0700
Subject: [PATCH] handle if even amount of batch sizes across devices

---
 musiclm_pytorch/distributed.py | 28 ++++++++++++++++++----------
 setup.py                       |  2 +-
 2 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/musiclm_pytorch/distributed.py b/musiclm_pytorch/distributed.py
index fa5d640..71364ac 100644
--- a/musiclm_pytorch/distributed.py
+++ b/musiclm_pytorch/distributed.py
@@ -1,26 +1,34 @@
 import torch
 from torch import nn
 from torch.autograd import Function
-import torch.distributed as distributed
+import torch.distributed as dist
 
 from einops import rearrange
 
 # distributed helpers
 
+def all_gather_same_dim(t):
+    world_size = dist.get_world_size()
+    gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
+    dist.all_gather(gathered_tensors, t)
+    return gathered_tensors
+
 def all_gather_variable_dim(t, dim = 0, sizes = None):
-    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()
+    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()
 
     if not exists(sizes):
         size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
-        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
-        distributed.all_gather(sizes, size)
+        sizes = all_gather_same_dim(size)
         sizes = torch.stack(sizes)
 
+    if torch.unique(sizes).numel() == 1:
+        gathered_tensors = all_gather_same_dim(t)
+        return torch.cat(gathered_tensors, dim = dim), sizes
+
     max_size = sizes.amax().item()
 
-    padded_t = pad_dim_to(t, max_size, dim = dim)
-    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
-    distributed.all_gather(gathered_tensors, padded_t)
+    padded_t = pad_dim_to(t, max_size, dim = dim)
+    gathered_tensors = all_gather_same_dim(padded_t)
 
     gathered_tensor = torch.cat(gathered_tensors, dim = dim)
     seq = torch.arange(max_size, device = device)
@@ -45,9 +53,9 @@ def forward(ctx, x, dim, sizes, all_reduce_grads):
 
     @staticmethod
     def backward(ctx, grads, _):
-        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
+        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
         if ctx.all_reduce_grads:
-            distributed.all_reduce(grads)
+            dist.all_reduce(grads)
         grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
         return grads_by_rank[rank], None, None, None
 
@@ -62,7 +70,7 @@ def __init__(
         super().__init__()
         self.dim = dim
         self.all_reduce_grads = all_reduce_grads
-        self.is_distributed = distributed.is_initialized() and distributed.get_world_size() > 1
+        self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1
 
     def forward(
         self,
diff --git a/setup.py b/setup.py
index cd7be6c..0f5a209 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'musiclm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.7',
+  version = '0.2.8',
   license='MIT',
   description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
   author = 'Phil Wang',
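
Note (not part of the patch): a minimal usage sketch of the gather path this commit touches. It assumes the module exposes an nn.Module named AllGather whose forward returns the gathered tensor together with the per-rank sizes, as the autograd function in the hunks above suggests; the class name, the example filename, and the torchrun launch line are illustrative, not taken from the patch.

# example.py -- launch with: torchrun --nproc_per_node=2 example.py
import torch
import torch.distributed as dist

from musiclm_pytorch.distributed import AllGather  # assumed export name

def main():
    # torchrun sets RANK / WORLD_SIZE / MASTER_ADDR env vars for us
    dist.init_process_group(backend = 'gloo')
    rank = dist.get_rank()

    # every rank contributes 4 rows here, so the patched code takes the
    # no-padding fast path; uneven row counts would go through pad_dim_to
    local = torch.randn(4, 8)

    all_gather = AllGather(dim = 0)
    gathered, sizes = all_gather(local)

    # `gathered` concatenates every rank's rows along dim 0,
    # `sizes` records how many rows each rank contributed
    print(rank, gathered.shape, sizes)

    dist.destroy_process_group()

if __name__ == '__main__':
    main()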