Update distributed optimizer with new coalescing manager API in PyTorch (NVIDIA#1663)

* Get compatibility with the latest _coalescing_manager

Signed-off-by: Masaki Kozuki <[email protected]>

* Make sure _coalescing_manager is backward compatible

Signed-off-by: Tim Moon <[email protected]>

* Debug coalescing manager kludges with multiple PyTorch versions

PyTorch 1.14.0 and 2.1.0

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Masaki Kozuki <[email protected]>
timmoon10 and crcrpar authored May 13, 2023
1 parent 2d876b8 commit 8b7a1ff
Showing 1 changed file with 73 additions and 22 deletions.
95 changes: 73 additions & 22 deletions apex/contrib/optimizers/distributed_fused_adam.py
@@ -5,7 +5,7 @@
 import io
 import itertools
 import threading
-import types
+from typing import List, Optional
 
 import torch
 from torch.distributed.distributed_c10d import _get_default_group
@@ -30,14 +30,66 @@
 from torch.distributed.distributed_c10d import _all_gather_base
 all_gather_into_tensor = _all_gather_base
 
-# Add args to coalescing manager if using PyTorch <=1.13.1
+# Import context manager to coalesce NCCL calls
+# Note: Replace these backward compatibility shims once PyTorch
+# exposes a stable public API for coalescing communication.
 from torch.distributed.distributed_c10d import _coalescing_manager
-if 'device' not in inspect.signature(_coalescing_manager).parameters.keys():
-    _coalescing_manager_nodevice = _coalescing_manager
+if 'device' not in inspect.signature(_coalescing_manager).parameters:
+    # PyTorch <=1.13.1 does not have device arg
+    _coalescing_manager_no_device_arg = _coalescing_manager
     @contextlib.contextmanager
     def _coalescing_manager(group, device, reqs):
-        with _coalescing_manager_nodevice(group, reqs):
+        with _coalescing_manager_no_device_arg(group, reqs):
             yield
+if 'reqs' in inspect.signature(_coalescing_manager).parameters:
+    # PyTorch <=2.0.1 handles synchronization externally to coalescing
+    # manager
+    _coalescing_manager_with_reqs_arg = _coalescing_manager
+    class _CoalescingManager:
+        def __init__(self):
+            self.works: List[torch.distributed.Work] = []
+        def append(self, work: torch.distributed.Work):
+            if work:
+                self.works.append(work)
+        def wait(self):
+            for work in self.works:
+                work.wait()
+    @contextlib.contextmanager
+    def _coalescing_manager(
+        group: Optional[torch.distributed.ProcessGroup] = None,
+        device: Optional[torch.device] = None,
+        async_ops: bool = False,
+    ):
+        assert device is not None
+        cm = _CoalescingManager()
+        with _coalescing_manager_with_reqs_arg(
+            group,
+            device,
+            cm.works,
+        ):
+            yield cm
+        if not async_ops:
+            cm.wait()
+    def _coalescing_manager_append_work(
+        cm: _CoalescingManager,
+        work: torch.distributed.Work,
+    ):
+        """Add asynchronous request to coalescing manager"""
+        cm.append(work)
+else:
+    # PyTorch >2.0.1 handles synchronization within coalescing
+    # manager
+    def _coalescing_manager_append_work(
+        cm: torch.distributed._CoalescingManager,
+        work: torch.distributed.Work,
+    ):
+        """Dummy function for backward compatibility
+
+        Coalescing manager already keeps track of asynchronous
+        communication.
+        """
+        pass
 
 # Import optional CUDA kernels
 _FOUND_DEPRECATED_FUSED_ADAM = False
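After these shims, every supported PyTorch version exposes one calling convention: a context manager taking (group, device, async_ops) that yields an object to which works can be appended, plus a wait() that blocks until the coalesced communication finishes. A minimal sketch of that convention, assuming the module-level shims above and an initialized NCCL process group (the helper name and tensors are illustrative):

import torch
import torch.distributed

def coalesced_all_reduce(tensors, group, device):
    """All-reduce a list of tensors as one coalesced NCCL launch."""
    with _coalescing_manager(group, device, async_ops=True) as cm:
        for tensor in tensors:
            # Each op is issued asynchronously; the manager fuses them
            # into a single grouped NCCL call when the context exits.
            _coalescing_manager_append_work(
                cm,
                torch.distributed.all_reduce(
                    tensor,
                    group=group,
                    async_op=True,
                ),
            )
    cm.wait()  # block until the coalesced communication completes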
@@ -582,21 +634,20 @@ def __init__(self,
 
     def _broadcast_params(self):
         """Broadcast parameter values from root rank"""
-        sync_requests = []
         process_group = self.process_group
-        with _coalescing_manager(process_group, self.device, sync_requests):
+        with _coalescing_manager(process_group, self.device, async_ops=True) as cm:
             for param_group in self.param_groups:
                 for param in param_group['params']:
-                    sync_requests.append(
+                    _coalescing_manager_append_work(
+                        cm,
                         torch.distributed.broadcast(
                             param,
                             src=self.process_group_root,
                             group=process_group,
                             async_op=True,
                         )
                     )
-        for req in sync_requests:
-            req.wait()
+        cm.wait()
 
     def _make_post_backward_hook(self, param, param_group_id, param_id):
         """Create callback function to call after param generates grad
@@ -1348,11 +1399,11 @@ def _start_bucket_grad_sync(self, buckets):
             with torch.cuda.stream(comm_stream):
                 for bucket in buckets:
                     bucket.sync_wait()
-                sync_requests = []
                 group = self.distributed_process_group
-                with _coalescing_manager(group, self.device, sync_requests):
+                with _coalescing_manager(group, self.device, async_ops=True) as cm:
                     for bucket in buckets:
-                        bucket.sync_request = (
+                        _coalescing_manager_append_work(
+                            cm,
                             reduce_scatter_tensor(
                                 bucket.sync_grads_shard,
                                 bucket.grads_bucket,
@@ -1361,26 +1412,26 @@
                                 async_op=True,
                             )
                         )
-                        sync_requests.append(bucket.sync_request)
+                cm.wait()
 
         # All-reduce over redundant process group
         if self.redundant_size > 1:
             with torch.cuda.stream(comm_stream):
                 for bucket in buckets:
                     bucket.sync_wait()
-                sync_requests = []
                 group = self.redundant_process_group
-                with _coalescing_manager(group, self.device, sync_requests):
+                with _coalescing_manager(group, self.device, async_ops=True) as cm:
                     for bucket in buckets:
-                        bucket.sync_request = (
+                        _coalescing_manager_append_work(
+                            cm,
                             torch.distributed.all_reduce(
                                 bucket.sync_grads_shard,
                                 op=reduce_op,
                                 group=group,
                                 async_op=True,
                             )
                         )
-                        sync_requests.append(bucket.sync_request)
+                cm.wait()
 
     def _finish_bucket_grad_sync(self):
         """Wait for any gradient synchronizations that are in progress"""
@@ -1503,19 +1554,19 @@ def _start_bucket_param_sync(self, buckets):
             with torch.cuda.stream(comm_stream):
                 for bucket in buckets:
                     bucket.sync_wait()
-                sync_requests = []
                 group = self.distributed_process_group
-                with _coalescing_manager(group, self.device, sync_requests):
+                with _coalescing_manager(group, self.device, async_ops=True) as cm:
                     for bucket in buckets:
-                        bucket.sync_request = (
+                        _coalescing_manager_append_work(
+                            cm,
                             all_gather_into_tensor(
                                 bucket.params_bucket,
                                 bucket.params_shard,
                                 group=group,
                                 async_op=True,
                             )
                         )
-                        sync_requests.append(bucket.sync_request)
+                cm.wait()
 
     def _finish_bucket_param_sync(self):
         """Wait for any param synchronizations that are in progress"""
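The parameter path is the mirror image: after the optimizer updates its local shard, a coalesced all-gather reassembles full flat buckets on every rank. A sketch under the same assumptions (the module's shims and all_gather_into_tensor fallback above; shard layout illustrative):

import torch
import torch.distributed

def all_gather_param_buckets(param_shards, group, device):
    """All-gather per-rank parameter shards into full flat buckets."""
    world_size = torch.distributed.get_world_size(group)
    buckets = [
        torch.empty(
            shard.numel() * world_size,
            dtype=shard.dtype,
            device=device,
        )
        for shard in param_shards
    ]
    with _coalescing_manager(group, device, async_ops=True) as cm:
        for bucket, shard in zip(buckets, param_shards):
            _coalescing_manager_append_work(
                cm,
                all_gather_into_tensor(
                    bucket,
                    shard,
                    group=group,
                    async_op=True,
                ),
            )
    cm.wait()  # buckets hold the gathered parameters once this returns
    return buckets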
