From 3f6df9a236569087a33cec28556416a429efe689 Mon Sep 17 00:00:00 2001
From: Logan Adams <114770087+loadams@users.noreply.github.com>
Date: Thu, 15 Aug 2024 11:06:46 -0700
Subject: [PATCH 1/3] Update version.txt after 0.14.5 release (#5982)

**Auto-generated PR to update version.txt after a DeepSpeed release**

Released version - 0.14.5
Author           - @loadams

Co-authored-by: loadams
---
 version.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/version.txt b/version.txt
index 436d0ce0df76..226468ee5b2e 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.14.5
+0.14.6

From 30428d0318acdb5f9d95e495c80722d6bd8f2a47 Mon Sep 17 00:00:00 2001
From: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
Date: Thu, 15 Aug 2024 11:27:10 -0500
Subject: [PATCH 2/3] move pynvml install to setup.py (#5840)

Only install pynvml on NVIDIA GPUs, not on all accelerators.
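For context, the `nvidia-ml-py` package provides the `pynvml` module that
DeepSpeed's CUDA path relies on for device queries. A minimal sketch of the
kind of NVML call that needs it (illustrative only, not part of this patch;
assumes at least one visible NVIDIA GPU):

```python
# Illustrative only: the sort of NVML query that requires nvidia-ml-py.
import pynvml  # module installed by the nvidia-ml-py package

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # first visible GPU
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)   # .total / .free / .used in bytes
print(f"GPU 0: {mem.used / 2**20:.0f} MiB used of {mem.total / 2**20:.0f} MiB")
pynvml.nvmlShutdown()
```

On ROCm builds and non-CUDA accelerators this module is unavailable, which is
why the dependency moves from requirements.txt into a conditional in setup.py.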
---
 requirements/requirements.txt | 1 -
 setup.py                      | 4 ++++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 05f88337f3a9..6840d6dbcc98 100755
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -1,7 +1,6 @@
 hjson
 ninja
 numpy
-nvidia-ml-py
 packaging>=20.0
 psutil
 py-cpuinfo

diff --git a/setup.py b/setup.py
index 2b7555361655..8707209526ad 100755
--- a/setup.py
+++ b/setup.py
@@ -92,6 +92,10 @@ def get_env_if_set(key, default: typing.Any = ""):
     'triton': fetch_requirements('requirements/requirements-triton.txt'),
 }
 
+# Only install pynvml on NVIDIA GPUs.
+if torch_available and get_accelerator().device_name() == 'cuda' and not is_rocm_pytorch:
+    install_requires.append('nvidia-ml-py')
+
 # Add specific cupy version to both onebit extension variants.
 if torch_available and get_accelerator().device_name() == 'cuda':
     cupy = None

From 9a3ede7079ad4e329bac7a4bd0799e30c212b7d2 Mon Sep 17 00:00:00 2001
From: inkcherry
Date: Fri, 16 Aug 2024 01:43:45 +0800
Subject: [PATCH 3/3] add moe topk(k>2) gate support (#5881)

Some users need top-k with k > 2 to train MoE models; see for example
Qwen2-57B-A14B (https://huggingface.co/Qwen/Qwen2-57B-A14B/blob/main/config.json).
This PR adds support for top-k (k > 2) gates:

- add topk (k > 2) support
- add drop-token policies based on position and on probabilities
- unit tests
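A minimal usage sketch of the new function (illustrative only; the shapes and
values are made up, and it assumes a single process with no expert parallelism):

```python
# Illustrative only: route 4 tokens across 8 experts with top-3 gating.
import torch
from deepspeed.moe.sharded_moe import topkgating

logits = torch.randn(4, 8)  # router logits: [num_tokens, num_experts]
l_aux, combine_weights, dispatch_mask, exp_counts = topkgating(
    logits,
    k=3,                   # k > 2 was previously rejected by TopKGate
    capacity_factor=1.0,
    min_capacity=1,
    drop_policy="probs",   # drop overflow by gate probability; "position" drops by token order
)
# combine_weights: [tokens, experts, capacity]; dispatch_mask = combine_weights.bool()
```

`TopKGate.forward` keeps dispatching to `top1gating`/`top2gating` for k in
{1, 2} and routes to `topkgating` otherwise, so existing configs are unaffected.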
---------

Co-authored-by: Kurt Chen
Co-authored-by: Jin, Youzhi
Co-authored-by: Olatunji Ruwase
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
---
 deepspeed/moe/sharded_moe.py | 88 ++++++++++++++++++++++++++++++++++--
 tests/unit/moe/test_moe.py   | 49 +++++++++++++++++++-
 2 files changed, 131 insertions(+), 6 deletions(-)

diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index 416f01b82e3d..c09a11e213db 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -124,6 +124,8 @@ def einsum(rule, a, b):
         return a.unsqueeze(2) * b.unsqueeze(1)
     elif rule == 'se,se->s':
         return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
+    elif rule == 'se,sec->sec':
+        return a.unsqueeze(2) * b
     elif rule == 'sec,sm->ecm':
         s = a.shape[0]
         e = a.shape[1]
@@ -191,8 +193,8 @@ def top1gating(logits: Tensor,
     if noisy_gate_policy == 'RSample':
         logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
     # everything is in fp32 in this function
-    gates = F.softmax(logits, dim=1)
 
+    gates = F.softmax(logits, dim=1)
     capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))
 
     # Create a mask for 1st's expert per token
@@ -369,6 +371,81 @@ def top2gating(logits: Tensor,
     return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu')
 
 
+def topkgating(
+    logits: Tensor,
+    k: int,
+    capacity_factor: float,
+    min_capacity: int,
+    drop_tokens: bool = True,
+    ep_group: Union[torch.distributed.ProcessGroup, None] = None,
+    drop_policy: str = "probs",
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    """Implements TopKGating on logits."""
+
+    # everything is in fp32 in this function
+    # get topk gates
+    top_gate, top_idx = torch.topk(logits, k=k, dim=1)
+    # gating decisions
+    gates = F.softmax(logits, dim=1)
+    num_experts = int(gates.shape[1])
+
+    # get topk mask
+    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_idx, top_gate)
+
+    mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(1, top_idx, 1)
+
+    exp_counts = torch.sum(mask, dim=0).detach().to(logits.device)
+
+    # Compute l_aux (load-balancing auxiliary loss)
+    me = torch.mean(gates, dim=0)
+    ce = torch.mean(mask.float(), dim=0)
+    l_aux = torch.mean(me * ce) * num_experts * num_experts / k
+
+    if drop_tokens:
+        # Calculate configured capacity and remove locations outside capacity from mask
+        capacity = _capacity(gates, torch.tensor(capacity_factor * k), torch.tensor(min_capacity))
+        # update mask and locations by capacity
+
+        if drop_policy == 'probs':
+            capacity_probs, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False)
+            capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
+            mask = torch.logical_and(mask, capacity_mask)
+            locations = torch.cumsum(mask, dim=0) - 1
+
+        elif drop_policy == "position":
+            locations = torch.cumsum(mask, dim=0) - 1
+            mask *= torch.lt(locations, capacity)
+        else:
+            raise ValueError(f"Invalid drop_policy: {drop_policy}")
+
+    else:
+        # Do not drop tokens - set capacity according to current expert assignments
+        new_capacity = torch.max(exp_counts)
+        if ep_group is not None:
+            dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
+        if groups._get_expert_model_parallel_world_size() == 1:
+            # If the non-expert is tensor-parallel, we need to pad the capacity to a multiple of 'tp',
+            # since drop_tokens() will be activated to drop duplicate tokens.
+            tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
+            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
+        capacity = new_capacity
+
+    # normalize gates
+    gates_masked = gates * mask
+    gates_s = torch.sum(gates_masked, dim=-1, keepdim=True)
+    denom_s = torch.clamp(gates_s, min=torch.finfo(gates_masked.dtype).eps)
+    gates_masked = gates_masked / denom_s
+
+    # dispatch_mask
+    locations_sc = _one_hot_to_float((locations * mask), capacity)
+
+    combine_weights = torch.einsum("se,sec->sec", gates_masked, locations_sc)
+
+    dispatch_mask = combine_weights.bool()
+
+    return l_aux, combine_weights, dispatch_mask, exp_counts
+
+
 class TopKGate(Module):
     """Gate module which implements Top2Gating as described in Gshard_.
 ::
 
@@ -401,9 +478,6 @@ def __init__(self,
                  top2_2nd_expert_sampling: bool = True) -> None:
         super().__init__()
 
-        # Only top-1 and top-2 are supported at the moment.
-        if k != 1 and k != 2:
-            raise ValueError('Only top-1 and top-2 gatings are supported.')
         self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
         self.ep_group = ep_group
         self.k = k
@@ -441,9 +515,13 @@ def forward(self,
                                      self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
                                      self.drop_tokens, self.use_rts, self.ep_group, use_tutel)
 
-        else:
+        elif self.k == 2:
             gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
                                      self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)
+        else:
+            gate_output = topkgating(logits, self.k,
+                                     self.capacity_factor if self.training else self.eval_capacity_factor,
+                                     self.min_capacity, self.drop_tokens, self.ep_group)
 
         if self.wall_clock_breakdown:
             self.timers(TOPK_GATE_TIMER).stop()

diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py
index fdff9430a4e6..f65d5e2a03bc 100644
--- a/tests/unit/moe/test_moe.py
+++ b/tests/unit/moe/test_moe.py
@@ -11,7 +11,7 @@
 from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
 import deepspeed.comm as dist
 from deepspeed import get_accelerator
-from deepspeed.moe.sharded_moe import top1gating
+from deepspeed.moe.sharded_moe import top1gating, topkgating
 from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param
 from deepspeed.utils.torch import required_torch_version
 
@@ -191,3 +191,50 @@ def test(self):
                   drop_tokens=False,
                   use_rts=True,
                   use_tutel=False)
+
+
+class TestTopkGate(DistributedTest):
+
+    def test(self):
+
+        def check_equal(logits, cap, sparse_truth, res):
+            m, n = logits.shape
+            dispatch_mask_truth = torch.zeros(m, n, cap)
+            i, j, k = sparse_truth.t()
+            dispatch_mask_truth[i, j, k] = 1
+            assert torch.equal(dispatch_mask_truth, res)
+
+        # s=4 e=4 topk=2 cap=2 (s*topk/e)
+        logits = torch.tensor([[0.11, 0.2, 0.1, 0.3], [0.3, 0.4, 0.11, 0.1], [0.11, 0.1, 0.6, 0.5],
+                               [0.1, 0.11, 0.7, 0.8]])
+        logits *= dist.get_rank() + 1
+        probs_dispatch_res = topkgating(logits, 2, 1, min_capacity=1, drop_policy='probs')[2]
+        probs_sec_sparse = torch.tensor([[0, 1, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 0], [3, 2, 1], [3, 3, 1]])
+        check_equal(logits, 2, probs_sec_sparse, probs_dispatch_res)
+
+        position_sec_sparse = torch.tensor([[0, 1, 0], [0, 3, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 1],
+                                            [3, 2, 1]])
+        position_dispatch_res = topkgating(logits, 2, 1, min_capacity=1, drop_policy='position')[2]
+        check_equal(logits, 2, position_sec_sparse, position_dispatch_res)
+
+        # s=4 e=6 topk=3 cap=2 (s*topk/e)
+        logits2 = torch.tensor([[0.5858, 0.4801, 0.6269, 0.5397, 0.9722, 0.7034],
+                                [0.5445, 0.6332, 0.4519, 0.6308, 0.0519, 0.6450],
+                                [0.4874, 0.8110, 0.7467, 0.8474, 0.0277, 0.3068],
+                                [0.8570, 0.6714, 0.5310, 0.3274, 0.4836, 0.9892]])
+        logits2 *= dist.get_rank() + 1
+
+        # top3 full mask    # prob_mask        # position_mask
+        # 0 0 1 0 1 1       # 0 0 1 0 1 1      # 0 0 1 0 1 1
+        # 0 1 0 1 0 1       # 0 0 0 1 0 0      # 0 1 0 1 0 1
+        # 0 1 1 1 0 0       # 0 1 1 1 0 0      # 0 1 1 1 0 0
+        # 1 1 0 0 0 1       # 1 1 0 0 0 1      # 1 0 0 0 0 0
+        probs_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='probs')[2]
+        probs_sec_sparse = torch.tensor([[0, 2, 0], [0, 4, 0], [0, 5, 0], [1, 3, 0], [2, 1, 0], [2, 2, 1], [2, 3, 1],
+                                         [3, 0, 0], [3, 1, 1], [3, 5, 1]])
+        check_equal(logits2, 2, probs_sec_sparse, probs_dispatch_res)
+
+        position_sec_sparse = torch.tensor([[0, 2, 0], [0, 4, 0], [0, 5, 0], [1, 1, 0], [1, 3, 0], [1, 5, 1],
+                                            [2, 1, 1], [2, 2, 1], [2, 3, 1], [3, 0, 0]])
+        position_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='position')[2]
+        check_equal(logits2, 2, position_sec_sparse, position_dispatch_res)