
Commit: up
Signed-off-by: Rui Qiao <[email protected]>
ruisearch42 committed Jan 23, 2025
1 parent 0f56424 commit 82bd905
Showing 2 changed files with 35 additions and 59 deletions.
92 changes: 34 additions & 58 deletions python/ray/dag/compiled_dag_node.py
@@ -7,7 +7,6 @@
TYPE_CHECKING,
Any,
Dict,
FrozenSet,
List,
Tuple,
Union,
@@ -864,6 +863,7 @@ def __init__(
f"got {default_communicator}"
)
self._default_communicator: Optional[Communicator] = default_communicator
self._default_communicator_id: Optional[str] = None

self._default_type_hint: ChannelOutputType = SharedMemoryType(
buffer_size_bytes=self._buffer_size_bytes,
@@ -932,16 +932,6 @@ def __init__(
# Mapping from the actor handle to the node ID that the actor is on.
# A None actor handle means the actor is the driver.
self.actor_to_node_id: Dict[Optional["ray.actor.ActorHandle"], str] = {}

# This is set to true when type hint of `transport="nccl"` is used.
self._use_default_nccl_group = False
# This is set to the specified custom communicator
# if there exists a type hint of `transport=custom_communicator`.
self._custom_communicator_p2p: Optional[Communicator] = None
# The NCCL group ID for P2P send/recv operations.
self._communicator_id_p2p: Optional[str] = None
# All the NCCL group IDs for P2P send/recv and collective operations.
self._communicator_ids: Set[str] = set()
# The index of the current execution. It is incremented each time
# the DAG is executed.
self._execution_index: int = -1
@@ -976,18 +966,10 @@ def _create_proxy_actor() -> "ray.actor.ActorHandle":
# we can lazily release the native buffers
self._destructed_ref_idxs: Dict[int, Set[Optional[int]]] = defaultdict(set)

@property
def communicator_id_p2p(self) -> Optional[str]:
return self._communicator_id_p2p

@property
def is_teardown(self) -> bool:
return self._is_teardown

@property
def communicator_ids(self) -> Set[str]:
return self._communicator_ids

def get_id(self) -> str:
"""
Get the unique ID of the compiled DAG.
@@ -1027,6 +1009,10 @@ def _preprocess(self) -> None:
communicator_to_actors: Dict[
Optional[Communicator], Set["ray.actor.ActorHandle"]
] = {}
communicator_to_type_hint: Dict[
Optional[Communicator],
Set["ray.experimental.channel.torch_tensor_type.TorchTensorType"],
] = {}
collective_ops: Set[_CollectiveOperation] = set()

input_attributes: Set[str] = set()
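As a reading aid for this hunk: `_preprocess` now records, per communicator, both the actors that will participate in it (`communicator_to_actors`) and the NCCL type hints that resolve to it (`communicator_to_type_hint`). Below is a toy sketch of that bookkeeping, using stand-in classes rather than Ray's real `Communicator`/`TorchTensorType` types; a later sketch after the `_init_communicators` hunk consumes these maps.

```python
from collections import defaultdict
from typing import Optional

class StubCommunicator:
    """Stand-in for a Ray Communicator object."""
    def __init__(self, name: str):
        self.name = name

class StubTypeHint:
    """Stand-in for TorchTensorType: may carry a custom communicator."""
    def __init__(self, custom: Optional[StubCommunicator] = None):
        self._custom = custom
        self.communicator_id: Optional[str] = None

    def get_custom_communicator(self) -> Optional[StubCommunicator]:
        return self._custom

    def set_communicator_id(self, group_id: str) -> None:
        self.communicator_id = group_id

# communicator -> actors participating in it, communicator -> type hints using it
communicator_to_actors = defaultdict(set)
communicator_to_type_hint = defaultdict(set)

def register(hint: StubTypeHint, actors, default: StubCommunicator) -> None:
    # Mirrors the selection rule: a custom communicator on the hint wins,
    # otherwise the DAG-level default communicator is used.
    comm = hint.get_custom_communicator() or default
    communicator_to_actors[comm].update(actors)
    communicator_to_type_hint[comm].add(hint)
```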
@@ -1121,10 +1107,16 @@ def _preprocess(self) -> None:
if dag_node.type_hint.requires_nccl():
communicator = self._select_communicator(dag_node)
communicator_to_actors[communicator].add(actor_handle)

communicator_to_type_hint[communicator].add(dag_node.type_hint)
# Collect NCCL collective operations.
if isinstance(dag_node, CollectiveOutputNode):
collective_ops.add(dag_node.collective_op)
communicator = self._select_communicator(dag_node)
communicator_to_actors[communicator].update(
dag_node.collective_op.actor_handles
)
communicator_to_type_hint[communicator].add(
dag_node.collective_op.type_hint
)
assert not self._overlap_gpu_communication, (
"Currently, the overlap_gpu_communication option is not "
"supported for NCCL collective operations. Please set "
@@ -1210,7 +1202,9 @@ def _preprocess(self) -> None:
if upstream_task.dag_node.type_hint.requires_nccl():
communicator = self._select_communicator(upstream_task.dag_node)
communicator_to_actors[communicator].add(downstream_actor_handle)

communicator_to_type_hint[communicator].add(
upstream_task.dag_node.type_hint
)
# Check that all specified input attributes, e.g., InputNode()["x"],
# are used in the DAG.
_check_unused_dag_input_attributes(output_node, input_attributes)
@@ -1219,7 +1213,7 @@

self._resolve_auto_transport(auto_transport_tasks, communicator_to_actors)

self._init_communicators(communicator_to_actors, collective_ops)
self._init_communicators(communicator_to_actors, communicator_to_type_hint)

if direct_input:
self._input_num_positional_args = 1
@@ -1234,20 +1228,14 @@ def _init_communicators(
communicator_to_actors: Dict[
Optional[Communicator], Set["ray.actor.ActorHandle"]
],
collective_ops: Set[
"ray.experimental.channel.collective_node._CollectiveOperation"
communicator_to_type_hint: Dict[
Optional[Communicator],
Set["ray.experimental.channel.torch_tensor_type.TorchTensorType"],
],
) -> None:
"""
Initialize communicators for the DAG.
"""
# Initialize and cache a NCCL group for each custom NCCL group. All the
# custom NCCL groups are initialized before the default NCCL groups.
custom_communicator_to_id: Dict[Communicator, str] = {}
# Initialize and cache a NCCL group for each set of actors. A set of actors
# can perform P2P send/recv and collective operations. If there are multiple
# custom NCCL groups for a set of actors, only one is cached.
actors_to_communicator_id: Dict[FrozenSet["ray.actor.ActorHandle"], str] = {}
for custom_communicator, actors in communicator_to_actors.items():
if None in actors:
raise ValueError("Driver cannot participate in the NCCL group.")
@@ -1257,24 +1245,10 @@ def _init_communicators(
custom_communicator,
self._overlap_gpu_communication,
)
custom_communicator_to_id[custom_communicator] = communicator_id
actors = frozenset(actors)
if actors not in actors_to_communicator_id:
actors_to_communicator_id[actors] = communicator_id

# If a custom communicator is specified for collective actors, initialize and
# cache the communicator ID.
for collective_op in collective_ops:
type_hint = collective_op.type_hint
custom_communicator = type_hint.get_custom_communicator()
if custom_communicator:
communicator_id = collective_op.init_communicator(
custom_communicator_to_id.get(custom_communicator, None)
)
custom_communicator_to_id[custom_communicator] = communicator_id
actors = frozenset(collective_op.actor_handles)
if actors not in actors_to_communicator_id:
actors_to_communicator_id[actors] = communicator_id
for type_hint in communicator_to_type_hint[custom_communicator]:
type_hint.set_communicator_id(communicator_id)
if custom_communicator == self._default_communicator:
self._default_communicator_id = communicator_id
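Continuing the toy sketch from the `_preprocess` hunk above (stand-in types, not Ray's API): one group is initialized per communicator, its ID is stamped onto every type hint that resolved to that communicator, and the default communicator's ID is remembered so `teardown` can destroy it later.

```python
import uuid
from typing import Optional

def init_communicators(communicator_to_actors, communicator_to_type_hint,
                       default_communicator) -> Optional[str]:
    default_communicator_id = None
    for comm, _actors in communicator_to_actors.items():
        # Stands in for _init_communicator(actors, comm, overlap_gpu_communication).
        group_id = str(uuid.uuid4())
        for hint in communicator_to_type_hint[comm]:
            hint.set_communicator_id(group_id)
        if comm is default_communicator:
            default_communicator_id = group_id
    return default_communicator_id

# Wiring it up with the stubs from the earlier sketch:
default = StubCommunicator("default")
h1, h2 = StubTypeHint(), StubTypeHint(StubCommunicator("custom"))
register(h1, {"actor_a", "actor_b"}, default)
register(h2, {"actor_b", "actor_c"}, default)
default_id = init_communicators(communicator_to_actors,
                                communicator_to_type_hint, default)
assert h1.communicator_id == default_id
assert h2.communicator_id != default_id   # the custom hint got its own group ID
```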

def _select_communicator(
self, dag_node: "ray.dag.DAGNode"
@@ -1283,7 +1257,14 @@ def _select_communicator(
If custom_communicator is provided (i.e., not None), use it.
Otherwise, use the default communicator.
"""
custom_communicator = dag_node.type_hint.get_custom_communicator()
from ray.dag.collective_node import CollectiveOutputNode

if isinstance(dag_node, CollectiveOutputNode):
custom_communicator = (
dag_node.collective_op.type_hint.get_custom_communicator()
)
else:
custom_communicator = dag_node.type_hint.get_custom_communicator()
if custom_communicator is not None:
return custom_communicator
if not self._create_default_communicator:
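Restated as a hedged standalone sketch (the error path for a missing default communicator is truncated by this hunk and omitted here): a collective output node looks at its collective op's type hint, any other node at its own type hint, and a custom communicator on that hint takes precedence over the default.

```python
from typing import Optional

def select_communicator_for(node, default_comm: Optional[object]) -> Optional[object]:
    # Collective output nodes carry the communicator choice on their collective
    # op's type hint; every other node carries it on its own type hint.
    collective_op = getattr(node, "collective_op", None)
    hint = collective_op.type_hint if collective_op is not None else node.type_hint
    custom = hint.get_custom_communicator()
    return custom if custom is not None else default_comm
```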
@@ -1443,10 +1424,6 @@ def _get_or_compile(
visited.add(cur_idx)

task = self.idx_to_task[cur_idx]
type_hint = task.dag_node.type_hint
if type_hint.requires_nccl():
type_hint.set_communicator_id(self._communicator_id_p2p)

if (
isinstance(task.dag_node, ClassMethodNode)
and task.dag_node.is_class_method_call
@@ -1517,7 +1494,7 @@ def _get_or_compile(
fn.remote(
do_allocate_channel,
reader_and_node_list,
type_hint,
task.dag_node.type_hint,
driver_actor_id,
)
)
@@ -2036,8 +2013,7 @@ def teardown(self, kill_actors: bool = False):
logger.exception("Error cancelling worker task")
pass

for communicator_id in outer._communicator_ids:
_destroy_communicator(communicator_id)
_destroy_communicator(outer._default_communicator_id)

logger.info("Waiting for worker tasks to exit")
self.wait_teardown(kill_actors=kill_actors)
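For context, a minimal usage sketch of the code path this file serves: a two-actor DAG whose intermediate tensor travels over NCCL, so that compile time triggers `_init_communicators` and `teardown` destroys the default group. The API names (`InputNode`, `with_type_hint`, `TorchTensorType`, `experimental_compile`) reflect roughly this era of Ray and the example needs two GPU workers; treat it as illustrative rather than a verified snippet.

```python
import ray
import torch
from ray.dag import InputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

@ray.remote(num_gpus=1)
class GPUWorker:
    def produce(self, _):
        return torch.ones(4, device="cuda")

    def consume(self, t):
        return t.sum().item()

sender, receiver = GPUWorker.remote(), GPUWorker.remote()

with InputNode() as inp:
    t = sender.produce.bind(inp)
    # NCCL transport for the tensor hop; communicator setup happens at compile time.
    t = t.with_type_hint(TorchTensorType(transport="nccl"))
    dag = receiver.consume.bind(t)

cdag = dag.experimental_compile()
print(ray.get(cdag.execute(0)))   # -> 4.0
cdag.teardown()                   # destroys the default communicator group
```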
1 change: 1 addition & 1 deletion
@@ -732,7 +732,7 @@ def _init_communicator(
custom_communicator: A custom NCCL group to initialize.
use_communication_streams: Whether to use dedicated send and recv
streams for communication. If True, communication and computation
can be overlapped to improve perfomrance.
can be overlapped to improve performance.
"""
ctx = ChannelContext.get_current()
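The docstring fixed in this hunk describes overlapping communication with computation by running communication on dedicated streams. A generic PyTorch illustration of that idea, independent of Ray's internal stream handling:

```python
import torch

assert torch.cuda.is_available()
comm_stream = torch.cuda.Stream()              # dedicated "communication" stream

x = torch.randn(1 << 20, device="cuda")
y = torch.empty_like(x)

# The communication stream must wait until x is ready on the default stream.
comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(comm_stream):
    y.copy_(x, non_blocking=True)              # stands in for an async NCCL send/recv

z = x * 2                                      # default-stream compute can overlap
comm_stream.synchronize()                      # wait for the "communication" result
print(z.sum().item(), y.sum().item())
```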

