Fix known issue of job context usage. (#1326)
* fix known issue

* lint

* optimized

* lint

* lint

* lint
BalaBalaYi authored Nov 8, 2024
1 parent d97cc4c commit ec15193
Showing 8 changed files with 67 additions and 50 deletions.
8 changes: 4 additions & 4 deletions dlrover/python/master/node/dist_job_manager.py
@@ -395,10 +395,10 @@ def _init_nodes(self):
update_nodes_priority(job_nodes)
self._job_context.update_job_nodes(job_nodes)

self._ps_manager.update_nodes()
self._chief_manager.update_nodes()
self._worker_manager.update_nodes()
self._evaluator_manager.update_nodes()
self._ps_manager.update_nodes_iter()
self._chief_manager.update_nodes_iter()
self._worker_manager.update_nodes_iter()
self._evaluator_manager.update_nodes_iter()

def _init_job_auto_scaler(self):
self._job_autoscaler: JobAutoScaler = new_job_auto_scaler(
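
Note: `update_nodes()` is renamed to `update_nodes_iter()` on each node manager; its body (shown in the training_node.py diff below) re-seeds the id/rank counters from the job context. A minimal runnable sketch of that behavior, using `_Stub*` stand-ins rather than the real dlrover classes:

```python
import itertools
from typing import Dict


class _StubJobContext:
    """Illustrative stand-in for dlrover's JobContext (not the real class)."""

    def __init__(self, nodes_by_type: Dict[str, Dict[int, object]]):
        self._nodes = nodes_by_type

    def job_nodes_by_type(self, node_type: str) -> Dict[int, object]:
        return self._nodes.get(node_type, {})


class _StubNodeManager:
    """Sketch of the manager-side behavior behind update_nodes_iter()."""

    def __init__(self, job_context: _StubJobContext, node_type: str):
        self._job_context = job_context
        self._node_type = node_type
        self._node_id_iter = itertools.count(0)
        self._node_rank_iter = itertools.count(0)

    def update_nodes_iter(self):
        # Re-seed the id/rank counters from the nodes currently held by
        # the shared job context, mirroring the diffed method body.
        nodes = self._job_context.job_nodes_by_type(self._node_type)
        self._node_id_iter = itertools.count(len(nodes))
        self._node_rank_iter = itertools.count(len(nodes))


ctx = _StubJobContext({"worker": {0: object(), 1: object()}})
manager = _StubNodeManager(ctx, "worker")
manager.update_nodes_iter()
print(next(manager._node_id_iter))  # 2: new ids continue after the existing nodes
```
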
15 changes: 11 additions & 4 deletions dlrover/python/master/node/job_context.py
@@ -48,15 +48,13 @@ def next_action(
):
return self._action_queue.next_action(instance=instance)

@property
def ps_nodes(self) -> Dict[int, Node]:
def get_mutable_ps_nodes(self) -> Dict[int, Node]:
with self._locker:
if NodeType.PS in self._job_nodes:
return self._job_nodes[NodeType.PS]
return {}

@property
def workers(self) -> Dict[int, Node]:
def get_mutable_worker_nodes(self) -> Dict[int, Node]:
with self._locker:
if NodeType.WORKER in self._job_nodes:
return self._job_nodes[NodeType.WORKER]
@@ -91,6 +89,15 @@ def _preprocess(self, node_type: str) -> str:
return NodeType.MASTER
return node_type

def update_job_nodes_by_type(self, node_type, job_nodes: Dict[int, Node]):
with self._locker:
if self._job_nodes is None:
self._job_nodes = {}
if node_type not in self._job_nodes:
self._job_nodes[node_type] = {}

self._job_nodes[node_type] = copy.deepcopy(job_nodes)

def update_job_nodes(self, job_nodes: Dict[str, Dict[int, Node]]):
with self._locker:
self._job_nodes = copy.deepcopy(job_nodes)
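
Note: the `ps_nodes`/`workers` properties become explicit `get_mutable_ps_nodes()`/`get_mutable_worker_nodes()` methods, and the new `update_job_nodes_by_type()` replaces one node type under the lock with a deep copy. A simplified, self-contained sketch of that pattern (`MiniJobContext` is an illustrative stand-in, not the real class):

```python
import copy
import threading
from typing import Dict


class MiniJobContext:
    """Simplified stand-in for the locking/deep-copy pattern above."""

    def __init__(self):
        self._locker = threading.Lock()
        self._job_nodes: Dict[str, Dict[int, dict]] = {}

    def get_mutable_nodes(self, node_type: str) -> Dict[int, dict]:
        # Return the live per-type dict (or an empty one), like the
        # get_mutable_* accessors that replace the old properties.
        with self._locker:
            return self._job_nodes.get(node_type, {})

    def update_job_nodes_by_type(self, node_type: str, job_nodes: Dict[int, dict]):
        # Replace one node type wholesale under the lock; the deep copy keeps
        # later caller-side mutations from leaking into the shared state.
        with self._locker:
            self._job_nodes[node_type] = copy.deepcopy(job_nodes)


ctx = MiniJobContext()
workers = {0: {"status": "PENDING"}}
ctx.update_job_nodes_by_type("worker", workers)
workers[0]["status"] = "RUNNING"              # mutate the caller's copy only
print(ctx.get_mutable_nodes("worker")[0])     # {'status': 'PENDING'}
```

The deep copy means changes made to the caller's dict after the update do not silently alter the shared context; they must be written back explicitly.
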
17 changes: 8 additions & 9 deletions dlrover/python/master/node/ps.py
@@ -69,7 +69,7 @@ def __init__(
self._init_training_ps_cluster()

def _ps_nodes(self):
return self._job_context.ps_nodes
return self._job_context.get_mutable_ps_nodes()

def _init_training_ps_cluster(self):
for node in self._ps_nodes().values():
@@ -93,7 +93,7 @@ def relaunch_node(self, node: Node, remove_exited_node=False):
node.is_released = True
new_id = next(self._node_id_iter)
new_node = node.get_relaunch_node_info(new_id)
self._job_context.update_job_node(new_node)
self._update_node(new_node)
if node in self._training_ps_cluster:
i = self._training_ps_cluster.index(node)
self._training_ps_cluster[i] = self._ps_nodes()[new_node.id]
@@ -155,7 +155,7 @@ def _scale_up_ps(self, up_num):
critical=True,
service_addr=service_addr,
)
self._job_context.update_job_node(ps)
self._update_node(ps)
new_ps.append(ps)
logger.info("Create PS %s", ps)
return new_ps
@@ -192,7 +192,7 @@ def process_after_ps_cluster_ready(self):
node.critical = False
node.relaunchable = False
node.is_released = True
self._job_context.update_job_node(node)
self._update_node(node)
if node.id in self._migrated_ps_nodes:
self._migrated_ps_nodes.pop(node.id)
plan.remove_nodes.append(node)
@@ -267,7 +267,7 @@ def _pre_drop_migrated_ps(self, alive_ps: List[Node]):
):
if node not in self._pre_dropped_ps:
node.migrated = True
self._job_context.update_job_node(node)
self._update_node(node)
self._pre_dropped_ps.append(node)

def get_total_request_cpu(self):
@@ -324,7 +324,7 @@ def delete_running_ps(self):
)
node.is_released = True
node.status = NodeStatus.DELETED
self._job_context.update_job_node(node)
self._update_node(node)

plan.remove_nodes.append(node)
return plan
@@ -371,7 +371,7 @@ def _migrate_parameter_server(self, name: str, cpu=0, memory=0):
service_addr=service_addr,
name=self._new_node_name_fn(NodeType.PS, new_ps_id),
)
self._job_context.update_job_node(new_node)
self._update_node(new_node)
self._migrated_ps_nodes[old_ps_id] = self._ps_nodes()[new_node.id]
logger.info("Migrated PS %s to PS %s", old_ps_id, new_ps_id)
return new_node
@@ -380,10 +380,9 @@ def exist_migrated_ps_nodes(self):
return len(self._migrated_ps_nodes) > 0

def is_all_running(self):
nodes = self._job_context.job_nodes_by_type(self._node_type)
running_ps = [
pod_info.id
for pod_info in nodes.values()
for pod_info in self._ps_nodes().values()
if pod_info.status == NodeStatus.RUNNING
]
return len(running_ps) == self._job_resource.ps_num
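
Note: `is_all_running()` now counts running parameter servers straight from `self._ps_nodes()` instead of re-querying the context by node type. A tiny sketch of the check itself (plain dicts stand in for `Node` objects):

```python
def is_all_running(ps_nodes, expected_ps_num):
    # Every requested parameter server must report RUNNING status.
    running = [n for n in ps_nodes.values() if n["status"] == "RUNNING"]
    return len(running) == expected_ps_num


print(is_all_running({0: {"status": "RUNNING"}, 1: {"status": "PENDING"}}, 2))  # False
```
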
17 changes: 9 additions & 8 deletions dlrover/python/master/node/training_node.py
@@ -221,7 +221,10 @@ def cur_nodes(self):
cur_nodes = [node.name for node in nodes.values()]
return cur_nodes

def update_nodes(self):
def _update_node(self, node: Node):
self._job_context.update_job_node(node)

def update_nodes_iter(self):
nodes = self._job_context.job_nodes_by_type(self._node_type)
self._node_id_iter = itertools.count(len(nodes))
self._node_rank_iter = itertools.count(len(nodes))
@@ -237,18 +240,16 @@ def remove_node(self, node_id):
logger.error("Unknown deletable worker id: %s" % node_id)
return
worker.is_released = True
self._job_context.update_job_node(worker)
self._update_node(worker)
plan.remove_nodes.append(worker)
return plan

def relaunch_node(self, node: Node, remove_exited_node=False):
plan = ScalePlan()
nodes = self._job_context.job_nodes_by_type(self._node_type)
with self._lock:
new_id = next(self._node_id_iter)
relaunch_node = node.get_relaunch_node_info(new_id)
nodes[new_id] = relaunch_node
self._job_context.update_job_node(relaunch_node)
self._update_node(relaunch_node)
logger.info("Relaunch node %s to %s", node.name, new_id)
plan.launch_nodes.append(
Node(
@@ -264,7 +265,7 @@ def relaunch_node(self, node: Node, remove_exited_node=False):
)
if remove_exited_node and not node.is_released and node.exited():
node.is_released = True
self._job_context.update_job_node(node)
self._update_node(node)
plan.remove_nodes.append(node)
return plan

@@ -280,7 +281,7 @@ def reduce_pending_node_resource(self):
reduced = reduce_timeout_pending_node_resource(node)
if reduced:
node.relaunchable = False
self._job_context.update_job_node(node)
self._update_node(node)
node_plan = self.relaunch_node(node)
plan.remove_nodes.append(node)
plan.merge(node_plan)
@@ -404,7 +405,7 @@ def running_nodes_hanged(self) -> List[bool]:
f"{timeout} from {date_time}!!!"
)
node.hang = hang
self._job_context.update_job_node(node)
self._update_node(node)
node_hang.append(hang)
return node_hang

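
Note: `TrainingNodeManager` gains an `_update_node()` helper so subclasses write node state back through a single method instead of calling `self._job_context.update_job_node()` directly. A runnable sketch of the wrapper pattern with stub classes (the `_Stub*` names are illustrative, not dlrover's):

```python
from dataclasses import dataclass


@dataclass
class _StubNode:
    id: int
    is_released: bool = False


class _StubContext:
    """Illustrative stand-in for the shared job context."""

    def __init__(self):
        self.nodes = {}

    def update_job_node(self, node):
        self.nodes[node.id] = node


class _BaseManager:
    """Sketch of the _update_node wrapper used by the managers."""

    def __init__(self, job_context):
        self._job_context = job_context

    def _update_node(self, node):
        # Single choke point for persisting node state back to the context.
        self._job_context.update_job_node(node)


class _WorkerManager(_BaseManager):
    def release(self, node):
        node.is_released = True
        self._update_node(node)  # subclasses no longer call the context directly


ctx = _StubContext()
_WorkerManager(ctx).release(_StubNode(id=0))
print(ctx.nodes[0].is_released)  # True
```

Funneling every write through one helper keeps a single place to later add logging or validation.
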
8 changes: 4 additions & 4 deletions dlrover/python/master/node/worker.py
@@ -171,7 +171,7 @@ def _scale_up_workers(self, up_num):
config_resource=copy.deepcopy(worker_resource),
service_addr=service_addr,
)
self._job_context.update_job_node(new_node)
self._update_node(new_node)
logger.info("Create worker %s", new_node)
plan.launch_nodes.append(new_node)
return plan
@@ -258,7 +258,7 @@ def migrate_workers(self, workers: Dict[str, NodeResource]):
rank_index=task_id,
name=self._new_node_name_fn(NodeType.WORKER, node_id),
)
self._job_context.update_job_node(new_node)
self._update_node(new_node)
plan.launch_nodes.append(new_node)
plan.remove_nodes.append(old_node)
return plan
@@ -323,7 +323,7 @@ def verify_restarting_training(self, node_id):
restart = worker.restart_training
# Set False to avoid restart repeatedly.
worker.restart_training = False
self._job_context.update_job_node(worker)
self._update_node(worker)
return restart

def is_training_hang_by_pending(self, total_node_num, job_type) -> bool:
@@ -367,7 +367,7 @@ def is_training_hang_by_pending(self, total_node_num, job_type) -> bool:
return False

# collect pending and running nodes
cur_nodes = list(self._job_context.workers.values())
cur_nodes = list(self._job_context.get_mutable_worker_nodes().values())
pending_workers: List[Node] = []
running_workers: List[Node] = []
for node in cur_nodes:
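
Note: the worker manager now reads workers through `job_context.get_mutable_worker_nodes()` wherever it previously used the `workers` property, for example when splitting pending from running workers in `is_training_hang_by_pending`. A simplified version of that classification loop (plain strings stand in for `NodeStatus` constants):

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class _Worker:
    id: int
    status: str  # "PENDING" / "RUNNING" / ... (stand-in for NodeStatus)


def split_pending_running(workers: Dict[int, _Worker]):
    """Simplified version of the pending/running split shown above."""
    pending: List[_Worker] = []
    running: List[_Worker] = []
    for node in workers.values():
        if node.status == "PENDING":
            pending.append(node)
        elif node.status == "RUNNING":
            running.append(node)
    return pending, running


pending, running = split_pending_running(
    {0: _Worker(0, "PENDING"), 1: _Worker(1, "RUNNING"), 2: _Worker(2, "RUNNING")}
)
print(len(pending), len(running))  # 1 2
```
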
4 changes: 3 additions & 1 deletion dlrover/python/tests/test_job_auto_scaler.py
@@ -173,7 +173,9 @@ def test_execute_job_optimization_plan(self):

for worker in worker_nodes.values():
worker.status = NodeStatus.RUNNING
self.job_context.update_job_node(worker)
self.job_context.update_job_nodes_by_type(
NodeType.WORKER, worker_nodes
)

manager._scaler.scale = mock.MagicMock(return_value=True)

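
Note: the test now seeds all worker nodes with one `update_job_nodes_by_type()` call instead of calling `update_job_node()` per worker inside the loop. A self-contained, test-style sketch of that batch-update pattern (the `_StubContext` below only mimics the method's locking and deep-copy behavior):

```python
import copy
import threading
import unittest


class _StubContext:
    """Stub that only mimics update_job_nodes_by_type's lock + deep copy."""

    def __init__(self):
        self._lock = threading.Lock()
        self._nodes = {}

    def update_job_nodes_by_type(self, node_type, job_nodes):
        with self._lock:
            self._nodes[node_type] = copy.deepcopy(job_nodes)

    def nodes_of(self, node_type):
        return self._nodes.get(node_type, {})


class BatchUpdateTest(unittest.TestCase):
    def test_batch_update_replaces_per_node_writes(self):
        ctx = _StubContext()
        workers = {i: {"status": "RUNNING"} for i in range(3)}
        # One call replaces the whole worker set, instead of calling a
        # per-node update inside the mutation loop as the old test did.
        ctx.update_job_nodes_by_type("worker", workers)
        self.assertEqual(len(ctx.nodes_of("worker")), 3)


if __name__ == "__main__":
    unittest.main()
```
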
6 changes: 3 additions & 3 deletions dlrover/python/tests/test_ps_manager.py
@@ -157,7 +157,7 @@ def test_migrate_parameter_servers(self):
self._elastic_job.get_node_name,
)

nodes = self._job_context.ps_nodes
nodes = self._job_context.get_mutable_ps_nodes()
for node in nodes.values():
node.status = NodeStatus.RUNNING
self._job_context.update_job_node(node)
@@ -170,13 +170,13 @@ def test_migrate_parameter_servers(self):
self.assertEqual(ps_manager._migrated_ps_nodes[0].id, 2)
self.assertTrue(ps_manager.exist_migrated_ps_nodes())

nodes = self._job_context.ps_nodes
nodes = self._job_context.get_mutable_ps_nodes()
ps_manager._pre_drop_migrated_ps(list(nodes.values()))
self.assertEqual(len(ps_manager._pre_dropped_ps), 0)
for node in nodes.values():
node.status = NodeStatus.RUNNING
self._job_context.update_job_node(node)
nodes = self._job_context.ps_nodes
nodes = self._job_context.get_mutable_ps_nodes()
ps_manager._pre_drop_migrated_ps(list(nodes.values()))
self.assertEqual(len(ps_manager._pre_dropped_ps), 1)

42 changes: 25 additions & 17 deletions dlrover/python/tests/test_worker_manager.py
@@ -65,12 +65,12 @@ def tearDown(self) -> None:

def test_scale_up_workers(self):
self._worker_manager._scale_up_workers(3)
workers = self.job_context.workers
workers = self.job_context.get_mutable_worker_nodes()
self.assertEqual(len(workers), 8)
self.assertEqual(workers[7].id, 7)

def test_scale_down_workers(self):
workers = list(self.job_context.workers.values())
workers = list(self.job_context.get_mutable_worker_nodes().values())
self._worker_manager._scale_down_workers(2, workers)
released_workers = []
for worker in workers:
@@ -79,7 +79,7 @@ def test_scale_down_workers(self):
self.assertEqual(len(released_workers), 2)

def test_delete_exited_workers(self):
workers = self.job_context.workers
workers = self.job_context.get_mutable_worker_nodes()
workers[3].status = NodeStatus.FINISHED
self.job_context.update_job_node(workers[3])
workers[4].status = NodeStatus.FAILED
@@ -93,7 +93,7 @@ def test_delete_exited_workers(self):
)

def test_delete_running_workers(self):
for node in self.job_context.workers.values():
for node in self.job_context.get_mutable_worker_nodes().values():
node.status = NodeStatus.RUNNING
self.job_context.update_job_node(node)
plan = self._worker_manager.delete_running_workers()
@@ -116,15 +116,15 @@ def test_relaunch_node(self):
self._elastic_job.get_node_service_addr,
self._elastic_job.get_node_name,
)
failed_worker = self.job_context.workers[4]
failed_worker = self.job_context.get_mutable_worker_nodes()[4]
failed_worker.status = NodeStatus.FAILED
failed_worker.max_relaunch_count = 3
self.job_context.update_job_node(failed_worker)
plan = worker_manager.relaunch_node(
failed_worker, remove_exited_node=True
)
self.assertEqual(plan.launch_nodes[0].config_resource.cpu, 16)
self.assertEqual(self.job_context.workers[5].id, 5)
self.assertEqual(self.job_context.get_mutable_worker_nodes()[5].id, 5)
self.assertEqual(plan.launch_nodes[0].max_relaunch_count, 3)
self.assertEqual(plan.remove_nodes[0].config_resource.cpu, 16)

@@ -156,14 +156,14 @@ def test_reduce_pending_node_resource(self):
self._elastic_job.get_node_service_addr,
self._elastic_job.get_node_name,
)
for node in self.job_context.workers.values():
for node in self.job_context.get_mutable_worker_nodes().values():
node.status = NodeStatus.PENDING
node.create_time = datetime.now() + timedelta(days=-1)
self.job_context.update_job_node(node)
plan = worker_manager.reduce_pending_node_resource()
self.assertEqual(len(plan.launch_nodes), 5)

for node in self.job_context.workers.values():
for node in self.job_context.get_mutable_worker_nodes().values():
node.config_resource.gpu_num = 1
self.job_context.update_job_node(node)

@@ -177,27 +177,31 @@ def test_pending_without_workers(self):
self._elastic_job.get_node_service_addr,
self._elastic_job.get_node_name,
)
for node in self.job_context.workers.values():
for node in self.job_context.get_mutable_worker_nodes().values():
node.status = NodeStatus.FAILED
node.exit_reason = NodeExitReason.FATAL_ERROR
self.job_context.update_job_node(node)
exited = worker_manager.has_exited_worker()
self.assertTrue(exited)

for node in self.job_context.workers.values():
for node in self.job_context.get_mutable_worker_nodes().values():
node.exit_reason = NodeExitReason.KILLED
self.job_context.update_job_node(node)
exited = worker_manager.has_exited_worker()
self.assertFalse(exited)

self.job_context.workers[0].status = NodeStatus.SUCCEEDED
self.job_context.update_job_node(self.job_context.workers[0])
self.job_context.get_mutable_worker_nodes()[
0
].status = NodeStatus.SUCCEEDED
self.job_context.update_job_node(
self.job_context.get_mutable_worker_nodes()[0]
)
exited = worker_manager.has_exited_worker()
self.assertTrue(exited)

wait = worker_manager.wait_worker_restart()
self.assertTrue(wait)
for node in self.job_context.workers.values():
for node in self.job_context.get_mutable_worker_nodes().values():
node.relaunch_count = node.max_relaunch_count
self.job_context.update_job_node(node)

@@ -213,12 +217,16 @@ def test_verify_restarting_training(self):
)
reset = worker_manager.verify_restarting_training(0)
self.assertFalse(reset)
self.job_context.workers[0].restart_training = True
self.job_context.update_job_node(self.job_context.workers[0])
self.job_context.get_mutable_worker_nodes()[0].restart_training = True
self.job_context.update_job_node(
self.job_context.get_mutable_worker_nodes()[0]
)
reset = worker_manager.verify_restarting_training(0)
self.assertTrue(reset)
self.job_context.workers[0].is_released = True
self.job_context.update_job_node(self.job_context.workers[0])
self.job_context.get_mutable_worker_nodes()[0].is_released = True
self.job_context.update_job_node(
self.job_context.get_mutable_worker_nodes()[0]
)
reset = worker_manager.verify_restarting_training(0)
self.assertFalse(reset)

