[transformer] UCC async test (NVIDIA#1417)
* add test

* update batch sizes

* update batch sizes

* small updates

* delete comment

* add async comm

* add sync if needed

* update tests

* remove redundant imports

* code cleanup

* minor updates

* update dtype for comparison

* fix dtypes

* fix typo

* modify sizes and use common_utils.find_free_port

* fix typo and use double precision

* revert some changes, create test for profiling on L1

* remove redundant line

* revert UCC_TLS and add sync to fwd_bwd

* code clean up

* code clean up

* modify BERT test

* add comment
Aidyn-A authored Jul 20, 2022
1 parent 809043f commit a29a698
Showing 6 changed files with 295 additions and 32 deletions.
36 changes: 24 additions & 12 deletions apex/transformer/pipeline_parallel/p2p_communication.py
@@ -53,35 +53,47 @@ def _run_p2pops(
     async_comm: bool = False
 ):
     ops = []
+    p2p_group = parallel_state.get_pipeline_model_parallel_group()
+    default_group = parallel_state.get_model_parallel_group()
+
+    need_to_sync = p2p_group.name() != default_group.name()
+
     if tensor_send_prev is not None:
         send_prev_op = torch.distributed.P2POp(
-            torch.distributed.isend,
-            tensor_send_prev,
-            parallel_state.get_pipeline_model_parallel_prev_rank(),
+            op=torch.distributed.isend,
+            tensor=tensor_send_prev,
+            peer=parallel_state.get_pipeline_model_parallel_prev_rank(),
+            group=p2p_group,
         )
         ops.append(send_prev_op)
     if tensor_recv_prev is not None:
         recv_prev_op = torch.distributed.P2POp(
-            torch.distributed.irecv,
-            tensor_recv_prev,
-            parallel_state.get_pipeline_model_parallel_prev_rank(),
+            op=torch.distributed.irecv,
+            tensor=tensor_recv_prev,
+            peer=parallel_state.get_pipeline_model_parallel_prev_rank(),
+            group=p2p_group,
        )
         ops.append(recv_prev_op)
     if tensor_send_next is not None:
         send_next_op = torch.distributed.P2POp(
-            torch.distributed.isend,
-            tensor_send_next,
-            parallel_state.get_pipeline_model_parallel_next_rank(),
+            op=torch.distributed.isend,
+            tensor=tensor_send_next,
+            peer=parallel_state.get_pipeline_model_parallel_next_rank(),
+            group=p2p_group,
         )
         ops.append(send_next_op)
     if tensor_recv_next is not None:
         recv_next_op = torch.distributed.P2POp(
-            torch.distributed.irecv,
-            tensor_recv_next,
-            parallel_state.get_pipeline_model_parallel_next_rank(),
+            op=torch.distributed.irecv,
+            tensor=tensor_recv_next,
+            peer=parallel_state.get_pipeline_model_parallel_next_rank(),
+            group=p2p_group,
         )
         ops.append(recv_next_op)
     if len(ops) > 0:
+        if need_to_sync:
+            torch.cuda.synchronize()
+
         reqs = torch.distributed.batch_isend_irecv(ops)
         if async_comm:
             assert len(reqs) == len(ops)
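For readers skimming the diff: _run_p2pops collects whatever sends and receives a pipeline stage needs, launches them with torch.distributed.batch_isend_irecv on an explicit pipeline-parallel group, and, when async_comm is set, hands the request handles back to the caller instead of waiting on them. Below is a minimal, standalone sketch of that pattern, not the Apex helper itself: the peer ranks, the process group, and the need-to-sync condition are plain arguments here rather than lookups in parallel_state.

from typing import Optional

import torch
import torch.distributed as dist


def batched_p2p(
    tensor_send_prev: Optional[torch.Tensor],
    tensor_recv_prev: Optional[torch.Tensor],
    tensor_send_next: Optional[torch.Tensor],
    tensor_recv_next: Optional[torch.Tensor],
    prev_rank: int,
    next_rank: int,
    group,
    need_to_sync: bool,
    async_comm: bool = False,
) -> list:
    """Batch isend/irecv ops on `group`; wait here unless async_comm is True."""
    ops = []
    if tensor_send_prev is not None:
        ops.append(dist.P2POp(op=dist.isend, tensor=tensor_send_prev, peer=prev_rank, group=group))
    if tensor_recv_prev is not None:
        ops.append(dist.P2POp(op=dist.irecv, tensor=tensor_recv_prev, peer=prev_rank, group=group))
    if tensor_send_next is not None:
        ops.append(dist.P2POp(op=dist.isend, tensor=tensor_send_next, peer=next_rank, group=group))
    if tensor_recv_next is not None:
        ops.append(dist.P2POp(op=dist.irecv, tensor=tensor_recv_next, peer=next_rank, group=group))
    if not ops:
        return []

    if need_to_sync:
        # When the p2p group is not the default group (e.g. UCC for p2p, NCCL for
        # everything else), finish pending GPU work before batching the p2p ops.
        torch.cuda.synchronize()

    requests = dist.batch_isend_irecv(ops)
    if async_comm:
        # Asynchronous mode: the caller must wait on these handles before it
        # reads the receive buffers or reuses the send buffers.
        return requests
    for req in requests:
        req.wait()
    return []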
@@ -34,6 +34,7 @@ def _forward_backward_pipelining_with_interleaving(
     grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
     disable_autocast: bool = False,
     deallocate_pipeline_outputs: bool = False,
+    async_comm: bool = False,
     sequence_parallel_enabled: bool = False,
     **kwargs,
 ) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
@@ -218,6 +219,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
             p2p_communication.recv_forward(
                 tensor_shape=tensor_shape,
                 dtype=dtype,
+                async_comm=async_comm,
                 sequence_parallel_enabled=sequence_parallel_enabled,
             )
         )
@@ -265,6 +267,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
             recv_next=recv_next,
             tensor_shape=tensor_shape,
             dtype=dtype,
+            async_comm=async_comm,
             sequence_parallel_enabled=sequence_parallel_enabled,
         )
         output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
@@ -275,6 +278,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
             recv_prev=recv_prev,
             tensor_shape=tensor_shape,
             dtype=dtype,
+            async_comm=async_comm,
             sequence_parallel_enabled=sequence_parallel_enabled,
         )
         input_tensors[next_forward_model_chunk_id].append(input_tensor)
@@ -359,6 +363,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
             recv_next=recv_next,
             tensor_shape=tensor_shape,
             dtype=dtype,
+            async_comm=async_comm,
             sequence_parallel_enabled=sequence_parallel_enabled,
         )
         free_output_tensor(output_tensor, deallocate_pipeline_outputs)
@@ -380,6 +385,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
             p2p_communication.recv_backward(
                 tensor_shape=tensor_shape,
                 dtype=dtype,
+                async_comm=async_comm,
                 sequence_parallel_enabled=sequence_parallel_enabled,
             )
         )
@@ -401,6 +407,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
             recv_next=recv_next,
             tensor_shape=tensor_shape,
             dtype=dtype,
+            async_comm=async_comm,
             sequence_parallel_enabled=sequence_parallel_enabled,
         )
     )
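Downstream, the interleaved schedule simply forwards async_comm into every recv_forward, recv_backward, and send/recv call, as the hunks above show. The sketch below illustrates how user code might enable it; the import path and the exact call shape are assumptions drawn from how this PR's tests drive the Apex schedules, not something this diff defines.

import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func


def run_interleaved_iteration(forward_step_func, batch, model_chunks, tensor_shape):
    # Assumed API: get_forward_backward_func(virtual_pipeline_model_parallel_size,
    # pipeline_model_parallel_size) returns the interleaved schedule when a
    # virtual pipeline size is configured.
    fwd_bwd = get_forward_backward_func(
        2,  # virtual_pipeline_model_parallel_size
        parallel_state.get_pipeline_model_parallel_world_size(),
    )
    return fwd_bwd(
        forward_step_func,      # user-provided forward + loss function
        batch,
        model_chunks,           # one module per virtual pipeline chunk
        forward_only=False,
        tensor_shape=tensor_shape,
        dtype=torch.float32,
        async_comm=True,        # p2p calls return handles instead of blocking
    )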
2 changes: 1 addition & 1 deletion apex/transformer/testing/distributed_test_base.py
@@ -115,7 +115,7 @@ def _setup_pre_spawn(self) -> None:
 
         self._has_ucx_tls = "UCX_TLS" in os.environ
         if not self._has_ucx_tls:
-            os.environ["UCX_TLS"] = "tcp,cuda_copy"
+            os.environ["UCX_TLS"] = "tcp,cuda"
         print('os.environ[\"UCX_TLS\"] = {}'.format(os.environ["UCX_TLS"]))
 
     def tearDown(self) -> None:
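For context on the transport change: the test base only pins UCX transports when the caller has not already set UCX_TLS. A hedged sketch of the equivalent standalone setup follows; the "ucc" backend name assumes a PyTorch build or torch_ucc plugin with UCC support, and the env:// rendezvous assumes MASTER_ADDR and MASTER_PORT are exported, neither of which this diff guarantees.

import os

import torch
import torch.distributed as dist


def init_ucc_process_group(rank: int, world_size: int) -> None:
    # Mirror the harness default: restrict UCX to TCP + CUDA transports unless
    # the caller already chose its own transport list.
    os.environ.setdefault("UCX_TLS", "tcp,cuda")
    dist.init_process_group(
        backend="ucc",            # requires UCC support in this PyTorch environment
        init_method="env://",     # expects MASTER_ADDR / MASTER_PORT to be set
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank % torch.cuda.device_count())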
7 changes: 6 additions & 1 deletion tests/L0/run_transformer/run_bert_minimal_test.py
@@ -186,7 +186,12 @@ def train(
     failure = None
     init = True
     try:
-        for virtual_pipeline_model_parallel_size in (2, None):
+        virtual_pipeline_model_parallel_sizes = (None, 2,)
+        if HAS_TORCH_UCC:
+            # Deliberately skipping test with interleaved schedule for BERT model.
+            # It deadlocks on hybrid UCC/NCCL backend.
+            virtual_pipeline_model_parallel_sizes = (None,)
+        for virtual_pipeline_model_parallel_size in virtual_pipeline_model_parallel_sizes:
             args = global_vars.get_args()
             async_comm = not args.sequence_parallel and virtual_pipeline_model_parallel_size is None
             data_idx = 0
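The two guards in this hunk reduce to a small pure function (a hypothetical helper, not part of the PR): async point-to-point communication is used only when sequence parallelism is off and the schedule is not interleaved, and the interleaved schedule is dropped entirely when torch-ucc is in use.

from typing import List, Optional, Tuple


def select_schedules(sequence_parallel: bool, has_torch_ucc: bool) -> List[Tuple[Optional[int], bool]]:
    """Return (virtual_pipeline_model_parallel_size, async_comm) pairs to run."""
    # The interleaved schedule (virtual size 2) deadlocks on the hybrid UCC/NCCL
    # backend, so it is skipped when torch-ucc is present.
    virtual_sizes: Tuple[Optional[int], ...] = (None,) if has_torch_ucc else (None, 2)
    return [(v, (not sequence_parallel) and v is None) for v in virtual_sizes]


# NCCL-only, no sequence parallelism: both schedules run, async only without interleaving.
assert select_schedules(sequence_parallel=False, has_torch_ucc=False) == [(None, True), (2, False)]
# With torch-ucc: only the non-interleaved schedule remains.
assert select_schedules(sequence_parallel=False, has_torch_ucc=True) == [(None, True)]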
56 changes: 38 additions & 18 deletions tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py
@@ -49,12 +49,19 @@ def init_weights(m):
     return init_weights
 
 
+def get_dtype_for_comparison():
+    if(torch.cuda.get_device_capability() >= (8, 0)):
+        return torch.float64
+    return torch.float32
+
+
 def get_target_loss_and_model(global_batch_shape: tuple, hidden_size: int, total_layers: int) -> Tuple[torch.Tensor, List[torch.Tensor]]:
     model = []
-    data = torch.ones(global_batch_shape, dtype=torch.double)
+    dtype = get_dtype_for_comparison()
+    data = torch.ones(global_batch_shape, dtype=dtype)
     for i in range(total_layers):
-        w = torch.ones((hidden_size, hidden_size), dtype=torch.double) * (i + 1.0) / weight_coeff
-        b = torch.ones(hidden_size, dtype=torch.double)
+        w = torch.ones((hidden_size, hidden_size), dtype=dtype) * (i + 1.0) / weight_coeff
+        b = torch.ones(hidden_size, dtype=dtype)
 
         w.requires_grad_()
         b.requires_grad_()
@@ -187,7 +194,8 @@ def _forward_backward_test_impl(
             deallocate_pipeline_output=deallocate_pipeline_outputs,
         )
 
-        if dtype == torch.double:
+        if dtype == get_dtype_for_comparison():
+            torch.cuda.synchronize()
             hidden_size = self.HIDDEN_SIZE
             microbatch_size = self.MICRO_BATCH_SIZE
             total_layers = pipeline_model_parallel_world_size
@@ -221,44 +229,56 @@ def _forward_backward_test_impl(
 
         parallel_state.destroy_model_parallel()
 
-    def test_no_pipelining(self):
+    def test_learning_no_pipelining(self):
         self._forward_backward_test_impl(False, forward_backward_no_pipelining, 1, None)
 
-    def test_no_pipelining_inference(self):
+    def test_inference_no_pipelining(self):
         self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None)
 
-    def test_pipelining_without_interleaving(self):
+    def test_learning_pipelining_without_interleaving(self):
         self._forward_backward_test_impl(
             False, forward_backward_pipelining_without_interleaving, None, None
         )
 
-    def test_pipelining_async(self):
+    def test_inference_pipelining_without_interleaving(self):
         self._forward_backward_test_impl(
-            False, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
+            True, forward_backward_pipelining_without_interleaving, None, None
         )
 
-    def test_pipelining_without_interleaving_inference(self):
+    def test_learning_async_pipelining_without_interleaving(self):
         self._forward_backward_test_impl(
-            True, forward_backward_pipelining_without_interleaving, None, None
+            False, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
         )
 
-    def test_pipelining_inference_async(self):
+    def test_inference_async_pipelining_without_interleaving(self):
         self._forward_backward_test_impl(
             True, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
         )
 
-    @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Megatron-LM voodoo")
-    def test_pipelining_with_interleaving(self):
+    @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
+    def test_learning_pipelining_with_interleaving(self):
         self._forward_backward_test_impl(
             False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
         )
 
-    @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Megatron-LM voodoo")
-    def test_pipelining_with_interleaving_inference(self):
+    @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
+    def test_inference_pipelining_with_interleaving(self):
         self._forward_backward_test_impl(
             True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
         )
 
+    @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
+    def test_learning_async_pipelining_with_interleaving(self):
+        self._forward_backward_test_impl(
+            False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True
+        )
+
+    @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
+    def test_inference_async_pipelining_with_interleaving(self):
+        self._forward_backward_test_impl(
+            True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True
+        )
 
 
 class NcclPipelineParallelForwardBackwardTest(NcclDistributedTestBase, PipelineParallelForwardBackwardTestBase):
 
@@ -283,10 +303,10 @@ def _test_hybrid_backends(self, forward_only: bool) -> None:
         ):
             self._run_hybrid_distributed_backend(forward_only)
 
-    def test_pipelining_without_interleaving_ucc_for_p2p(self):
+    def test_learning_pipelining_without_interleaving_ucc_for_p2p(self):
         self._test_hybrid_backends(False)
 
-    def test_pipelining_without_interleaving_inference_ucc_for_p2p(self):
+    def test_inference_pipelining_without_interleaving_ucc_for_p2p(self):
         self._test_hybrid_backends(True)
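A note on get_dtype_for_comparison above: torch.cuda.get_device_capability() returns a (major, minor) tuple, so the reference comparison runs in float64 only on compute capability 8.0 (Ampere) and newer and falls back to float32 elsewhere. A small standalone check, assuming nothing beyond a CUDA-enabled PyTorch install:

import torch

if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()      # e.g. (8, 0) on A100
    comparison_dtype = torch.float64 if capability >= (8, 0) else torch.float32
    print("compute capability {} -> comparing in {}".format(capability, comparison_dtype))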