Skip to content

Commit

Permalink
Update test_ulysses.py
Browse files Browse the repository at this point in the history
Skip v100 test
  • Loading branch information
samadejacobs authored Aug 20, 2024
1 parent 76c67c0 commit cb7c20e
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/unit/sequence_parallelism/test_ulysses.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from transformers import AutoModel
from unit.common import DistributedTest
from deepspeed.sequence.layer import _SeqAllToAll
from unit.util import skip_on_arch


#Use mesh device to create data and sequence parallel group
class TestUlyssesUtils(DistributedTest):
world_size = 4

def test_mesh_device_creation(self) -> None:
skip_on_arch(min_arch=8)
model = AutoModel.from_pretrained('bert-base-uncased')
sp_size = 2
dp_size = 2
Expand Down Expand Up @@ -44,6 +46,7 @@ class TestUlyssesAll2All(DistributedTest):
world_size = 4

def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None:
skip_on_arch(min_arch=8)
model = AutoModel.from_pretrained('bert-base-uncased')
ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": 8}, mesh_param=(2, 2))
#4D tensor : b,s,h,d or s,b,h,d
Expand Down

0 comments on commit cb7c20e

Please sign in to comment.