From 8531d1c527560153281b6b7646789aa9082e0026 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Tue, 28 May 2024 19:40:58 -0700 Subject: [PATCH] [MoE] Test sorting lhs for gmm (#7121) Summary: This pull request adds a test case that sort the lhs and produce group_sizes for gmm. Test Plan: python test/test_gmm.py -v -k test_sorting_input --- test/test_gmm.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/test/test_gmm.py b/test/test_gmm.py index bd2802d3e06..08483b0dd84 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -251,6 +251,32 @@ def test_histogram_raise(self): max=3, ) + def test_sorting_input(self): + met.clear_all() + top2 = torch.tensor([[0, 2], [1, 3], [1, 2], [2, 3]]).to("xla") + + # We want to create one big batch of tokens that has all top-k choices in it. + # Our tokens will thus be duplicated k-times in the batch. To do this we, + # first flatten the expert choices list and argsort it. This gives us an array + # of length B * K. We then create a tiled arange of size B * K and index + # into the expert choices list. This will give us the set of indices we need + # to gather from the xs to create this big batch. + top_flat = top2.flatten() + lhs_order = top_flat.argsort() + lhs_reverse_order = lhs_order.argsort() + lhs_indices = torch.arange( + top2.shape[0], device="xla").repeat_interleave(2)[lhs_order] + group_sizes = _histogram(top_flat.to(torch.int32), 0, 3) + xm.mark_step() + + # Make sure it doesn't fallback. + self.assertNotIn("aten::", met.short_metrics_report()) + self.assertTrue( + torch.all(lhs_indices == torch.tensor([0, 1, 2, 0, 3, 2, 1, 3], + device="xla"))) + self.assertTrue( + torch.all(group_sizes == torch.tensor([1, 2, 3, 2], device="xla"))) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)