Skip to content

Commit

Permalink
fix crash in torch2.6 if TP=1
Browse files Browse the repository at this point in the history
error like "ValueError: Expecting a ProcessGroup, but got a <class
'text_generation_server.utils.dist.FakeGroup'>. rank=0"

Signed-off-by: Wang, Yi A <[email protected]>
  • Loading branch information
sywangyi committed Jan 7, 2025
1 parent 23bc38b commit 54eafc1
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions server/text_generation_server/utils/dist.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import torch

from torch.distributed import ProcessGroup
from datetime import timedelta
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
Expand All @@ -18,10 +18,11 @@ def wait(self):
pass


class FakeGroup:
class FakeGroup(ProcessGroup):
def __init__(self, rank, size):
self._rank = rank
self._size = size
super().__init__(rank, size)

def allreduce(self, *args, **kwargs):
return FakeBarrier()
Expand Down

0 comments on commit 54eafc1

Please sign in to comment.