Commit 3639e6e

optimization (#1327)

BalaBalaYi authored Nov 8, 2024
1 parent ec15193 commit 3639e6e
Showing 4 changed files with 37 additions and 33 deletions.
16 changes: 8 additions & 8 deletions dlrover/python/master/node/job_context.py
@@ -48,16 +48,16 @@ def next_action(
     ):
         return self._action_queue.next_action(instance=instance)
 
-    def get_mutable_ps_nodes(self) -> Dict[int, Node]:
-        with self._locker:
-            if NodeType.PS in self._job_nodes:
-                return self._job_nodes[NodeType.PS]
-            return {}
+    def get_mutable_ps_nodes(self):
+        return self.get_mutable_job_nodes(NodeType.PS)
+
+    def get_mutable_worker_nodes(self):
+        return self.get_mutable_job_nodes(NodeType.WORKER)
 
-    def get_mutable_worker_nodes(self) -> Dict[int, Node]:
+    def get_mutable_job_nodes(self, node_type) -> Dict[int, Node]:
         with self._locker:
-            if NodeType.WORKER in self._job_nodes:
-                return self._job_nodes[NodeType.WORKER]
+            if node_type in self._job_nodes:
+                return self._job_nodes[node_type]
             return {}
 
     def job_nodes(self) -> Dict[str, Dict[int, Node]]:
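
The job_context.py hunk above collapses two near-identical, type-specific getters into a single get_mutable_job_nodes(node_type) accessor guarded by the lock. The following is a minimal, self-contained sketch of that pattern, not the real JobContext: the SimpleJobContext and Node stand-ins and the demo at the bottom are illustrative assumptions; only the method names and the NodeType/_locker/_job_nodes attributes come from the diff.

import threading
from typing import Dict


class NodeType:
    PS = "ps"
    WORKER = "worker"


class Node:
    def __init__(self, node_id: int, node_type: str):
        self.id = node_id
        self.type = node_type


class SimpleJobContext:
    def __init__(self):
        self._locker = threading.Lock()
        self._job_nodes: Dict[str, Dict[int, Node]] = {}

    def get_mutable_ps_nodes(self) -> Dict[int, Node]:
        # Thin wrapper kept for callers that still ask by concrete type.
        return self.get_mutable_job_nodes(NodeType.PS)

    def get_mutable_worker_nodes(self) -> Dict[int, Node]:
        return self.get_mutable_job_nodes(NodeType.WORKER)

    def get_mutable_job_nodes(self, node_type: str) -> Dict[int, Node]:
        # One locked lookup replaces the duplicated per-type branches.
        with self._locker:
            if node_type in self._job_nodes:
                return self._job_nodes[node_type]
            return {}


if __name__ == "__main__":
    ctx = SimpleJobContext()
    ctx._job_nodes[NodeType.PS] = {0: Node(0, NodeType.PS)}
    print(list(ctx.get_mutable_ps_nodes()))  # [0]
    print(ctx.get_mutable_worker_nodes())    # {}
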
29 changes: 15 additions & 14 deletions dlrover/python/master/node/ps.py
@@ -68,11 +68,9 @@ def __init__(
         self._node_id_iter = itertools.count(self._job_resource.ps_num)
         self._init_training_ps_cluster()
 
-    def _ps_nodes(self):
-        return self._job_context.get_mutable_ps_nodes()
-
     def _init_training_ps_cluster(self):
-        for node in self._ps_nodes().values():
+        ps_nodes = self._get_nodes()
+        for node in ps_nodes.values():
             alive = node.status in [
                 NodeStatus.INITIAL,
                 NodeStatus.PENDING,
@@ -96,7 +94,7 @@ def relaunch_node(self, node: Node, remove_exited_node=False):
         self._update_node(new_node)
         if node in self._training_ps_cluster:
             i = self._training_ps_cluster.index(node)
-            self._training_ps_cluster[i] = self._ps_nodes()[new_node.id]
+            self._training_ps_cluster[i] = self._get_nodes()[new_node.id]
         logger.info("Relaunch node %s to %s", node.name, new_id)
         plan.launch_nodes.append(
             Node(
@@ -201,7 +199,8 @@ def process_after_ps_cluster_ready(self):
     def _get_alive_ps(self) -> List[Node]:
         """Get all running PS pods"""
         alive_ps = []
-        for node in self._ps_nodes().values():
+        ps_nodes = self._get_nodes()
+        for node in ps_nodes.values():
             if node.status == NodeStatus.RUNNING and not node.is_released:
                 alive_ps.append(node)
         return alive_ps
@@ -215,7 +214,8 @@ def get_next_training_ps_cluster(self):
             return self._next_training_ps_cluster
 
         all_new_ps_ready = True
-        for node in self._ps_nodes().values():
+        ps_nodes = self._get_nodes()
+        for node in ps_nodes.values():
             if self._wait_ps_node(node):
                 all_new_ps_ready = False
                 break
@@ -236,7 +236,8 @@ def has_ps_failure(self):
         Check whether there is PS failure and the master does not relaunch
         the failed PS node.
         """
-        for node in self._ps_nodes().values():
+        ps_nodes = self._get_nodes()
+        for node in ps_nodes.values():
             if node.timeout(_dlrover_ctx.seconds_to_wait_failed_ps):
                 return True
         return False
@@ -293,7 +294,7 @@ def get_ready_for_new_ps_cluster(self):
     def get_ps_addrs(self):
         """Get the address list of ps services"""
         ps_addrs = {}
-        nodes = self._ps_nodes()
+        nodes = self._get_nodes()
         for ps in list(nodes.values()):
             if (
                 ps.id not in self._migrated_ps_nodes
@@ -309,7 +310,7 @@ def delete_running_ps(self):
     def delete_running_ps(self):
         """Delete all running ps pods"""
         plan = ScalePlan()
-        nodes = self._ps_nodes()
+        nodes = self._get_nodes()
         for node in list(nodes.values()):
             if (
                 node.status in [NodeStatus.RUNNING, NodeStatus.PENDING]
@@ -344,7 +345,7 @@ def _migrate_parameter_server(self, name: str, cpu=0, memory=0):
         old_ps_id = int(name.split("-")[-1])
         if old_ps_id in self._migrated_ps_nodes:
             return
-        nodes = self._ps_nodes()
+        nodes = self._get_nodes()
         if old_ps_id not in nodes:
             logger.error(f"not found PS-{old_ps_id} in job")
             return
@@ -372,7 +373,7 @@ def _migrate_parameter_server(self, name: str, cpu=0, memory=0):
             name=self._new_node_name_fn(NodeType.PS, new_ps_id),
         )
         self._update_node(new_node)
-        self._migrated_ps_nodes[old_ps_id] = self._ps_nodes()[new_node.id]
+        self._migrated_ps_nodes[old_ps_id] = self._get_nodes()[new_node.id]
         logger.info("Migrated PS %s to PS %s", old_ps_id, new_ps_id)
         return new_node
 
@@ -382,7 +383,7 @@ def exist_migrated_ps_nodes(self):
     def is_all_running(self):
         running_ps = [
            pod_info.id
-            for pod_info in self._ps_nodes().values()
+            for pod_info in self._get_nodes().values()
             if pod_info.status == NodeStatus.RUNNING
         ]
         return len(running_ps) == self._job_resource.ps_num
@@ -428,7 +429,7 @@ def is_training_hang_by_pending(self, total_node_num, job_type) -> bool:
             return False
 
         # collect pending and running nodes
-        cur_nodes = list(self._ps_nodes().values())
+        cur_nodes = list(self._get_nodes().values())
         pending_ps: List[Node] = []
         running_ps: List[Node] = []
         for node in cur_nodes:
3 changes: 3 additions & 0 deletions dlrover/python/master/node/training_node.py
@@ -221,6 +221,9 @@ def cur_nodes(self):
         cur_nodes = [node.name for node in nodes.values()]
         return cur_nodes
 
+    def _get_nodes(self):
+        return self._job_context.get_mutable_job_nodes(self._node_type)
+
     def _update_node(self, node: Node):
         self._job_context.update_job_node(node)
 
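
The training_node.py hunk above adds a shared _get_nodes() helper on the base manager, so the PS and worker managers in ps.py and worker.py can drop their per-type lookups. Below is a hedged sketch of that design, using stand-in classes (FakeJobContext, TrainingNodeManagerSketch, PSManagerSketch, WorkerManagerSketch and their alive_*_ids methods are illustrative, not the real dlrover managers); only _get_nodes, _job_context, _node_type, and get_mutable_job_nodes are taken from the diff.

from typing import Dict


class FakeJobContext:
    """Stand-in for the job context the managers consult."""

    def __init__(self, nodes_by_type: Dict[str, Dict[int, str]]):
        self._nodes_by_type = nodes_by_type

    def get_mutable_job_nodes(self, node_type: str) -> Dict[int, str]:
        return self._nodes_by_type.get(node_type, {})


class TrainingNodeManagerSketch:
    def __init__(self, job_context: FakeJobContext, node_type: str):
        self._job_context = job_context
        self._node_type = node_type

    def _get_nodes(self) -> Dict[int, str]:
        # Subclasses reuse this instead of calling the context with a
        # hard-coded node type.
        return self._job_context.get_mutable_job_nodes(self._node_type)


class PSManagerSketch(TrainingNodeManagerSketch):
    def alive_ps_ids(self):
        return sorted(self._get_nodes())


class WorkerManagerSketch(TrainingNodeManagerSketch):
    def alive_worker_ids(self):
        return sorted(self._get_nodes())


if __name__ == "__main__":
    ctx = FakeJobContext({"ps": {0: "ps-0"}, "worker": {0: "worker-0", 1: "worker-1"}})
    print(PSManagerSketch(ctx, "ps").alive_ps_ids())              # [0]
    print(WorkerManagerSketch(ctx, "worker").alive_worker_ids())  # [0, 1]
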
22 changes: 11 additions & 11 deletions dlrover/python/master/node/worker.py
@@ -63,7 +63,7 @@ def __init__(
     def is_chief_running(self):
         """The chief worker with id=0 is responsible to initialize
         variables in TensorFlow 1.x PS strategy"""
-        nodes = self._job_context.job_nodes_by_type(self._node_type)
+        nodes = self._get_nodes()
         for node in nodes.values():
             if node.status == NodeStatus.RUNNING:
                 return True
@@ -97,7 +97,7 @@ def __init__(
     def is_chief_running(self):
         """The chief worker with id=0 is responsible to initialize
         variables in TensorFlow 1.x PS strategy"""
-        nodes = self._job_context.job_nodes_by_type(self._node_type)
+        nodes = self._get_nodes()
         for node in nodes.values():
             if node.status == NodeStatus.RUNNING:
                 return True
@@ -140,7 +140,7 @@ def adjust_worker(self, worker_resource: NodeGroupResource):
             )
         )
         alive_workers = []
-        nodes = self._job_context.job_nodes_by_type(self._node_type)
+        nodes = self._get_nodes()
         for worker in nodes.values():
             if worker.status in ALIVE_STATUS:
                 alive_workers.append(worker)
@@ -192,7 +192,7 @@ def _scale_down_workers(self, down_num, running_workers: List[Node]):
     def delete_exited_workers(self):
         """Delete failed, succeed, finished workers."""
         plan = ScalePlan()
-        nodes = self._job_context.job_nodes_by_type(self._node_type)
+        nodes = self._get_nodes()
         with self._lock:
             for worker in nodes.values():
                 if (
@@ -210,7 +210,7 @@ def delete_exited_workers(self):
 
     def delete_running_workers(self):
         plan = ScalePlan()
-        nodes = self._job_context.job_nodes_by_type(self._node_type)
+        nodes = self._get_nodes()
         for worker in nodes.values():
             if not worker.critical and worker.status in [
                 NodeStatus.RUNNING,
@@ -239,7 +239,7 @@ def remove_noncritical_worker(self, worker_id):
     def migrate_workers(self, workers: Dict[str, NodeResource]):
         """Migrate workers with the new resource"""
         plan = ScalePlan()
-        nodes = self._job_context.job_nodes_by_type(self._node_type)
+        nodes = self._get_nodes()
         for name, resource in workers.items():
             old_node_id = int(name.split("-")[-1])
             old_node = nodes[old_node_id]
@@ -269,7 +269,7 @@ def remove_not_joined_rdzv_workers(self, worker_ranks: List[int]):
             worker_ranks: The rank of worker which does not join rendezvous.
         """
         plan = ScalePlan()
-        nodes = self._job_context.job_nodes_by_type(self._node_type)
+        nodes = self._get_nodes()
         for node_id, node in nodes.items():
             if node.rank_index in worker_ranks:
                 p = self.remove_node(node.id)
@@ -280,7 +280,7 @@ def remove_not_joined_rdzv_workers(self, worker_ranks: List[int]):
 
     def has_exited_worker(self):
         """Check whether there is exited worker except evicted workers."""
-        nodes = self._job_context.job_nodes_by_type(self._node_type)
+        nodes = self._get_nodes()
         for worker in nodes.values():
             if (
                 worker.exit_reason == NodeExitReason.FATAL_ERROR
@@ -291,7 +291,7 @@ def has_exited_worker(self):
 
     def wait_worker_restart(self):
         """Check whether there are workers tha have remaining retries."""
-        nodes = self._job_context.job_nodes_by_type(self._node_type)
+        nodes = self._get_nodes()
         for worker in nodes.values():
             if (
                 worker.exit_reason == NodeExitReason.KILLED
@@ -367,7 +367,7 @@ def is_training_hang_by_pending(self, total_node_num, job_type) -> bool:
             return False
 
         # collect pending and running nodes
-        cur_nodes = list(self._job_context.get_mutable_worker_nodes().values())
+        cur_nodes = list(self._get_nodes().values())
         pending_workers: List[Node] = []
         running_workers: List[Node] = []
         for node in cur_nodes:
@@ -496,7 +496,7 @@ def is_training_hang_by_insufficient_worker(self) -> bool:
             f"Is training hang by insufficient worker with timeout: {timeout}."
         )
 
-        nodes = self._job_context.job_nodes_by_type(self._node_type)
+        nodes = self._get_nodes()
         cur_nodes = list(nodes.values())
 
         # collect available nodes
