diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py
index 733bfb044..27b79a851 100644
--- a/dlrover/python/elastic_agent/torch/training.py
+++ b/dlrover/python/elastic_agent/torch/training.py
@@ -327,6 +327,7 @@ def next_rendezvous(self):
         )
         logger.info(msg)
         self._join_rendezvous()
 
+        start_pending = 0
         while True:
             self._check_network_rdzv_for_elastic_training()
@@ -848,7 +849,10 @@ def _assign_worker_ranks(
         return workers
 
     def _initialize_workers(self, worker_group):
-        logger.info("Start initializing training workers.")
+        logger.info(
+            "Start initializing "
+            f"training({self.__class__.__name__}) workers."
+        )
         start_pending = 0
         pend_timeout = float(
             self._config.rdzv_configs.get("pend_timeout", "inf")
diff --git a/dlrover/python/master/node/dist_job_manager.py b/dlrover/python/master/node/dist_job_manager.py
index 7c3e376e6..669db933c 100644
--- a/dlrover/python/master/node/dist_job_manager.py
+++ b/dlrover/python/master/node/dist_job_manager.py
@@ -1163,10 +1163,13 @@ def _process_error(
         error_data: str,
         level: str,
     ) -> bool:
-        if self._error_monitor and node is not None:
-            return self._error_monitor.process_error(
-                node, restart_count, error_data, level
-            )
+        if node:
+            if level == TrainingExceptionLevel.NODE_ERROR:
+                self._job_context.report_failed_node(node.id)
+            if self._error_monitor:
+                return self._error_monitor.process_error(
+                    node, restart_count, error_data, level
+                )
         return False
 
     def all_running_node_hanged(self):
diff --git a/dlrover/python/master/node/event_callback.py b/dlrover/python/master/node/event_callback.py
index f7dc1d598..40a057cab 100644
--- a/dlrover/python/master/node/event_callback.py
+++ b/dlrover/python/master/node/event_callback.py
@@ -29,6 +29,7 @@
     RendezvousManager,
 )
 from dlrover.python.master.monitor.speed_monitor import SpeedMonitor
+from dlrover.python.master.node.job_context import get_job_context
 from dlrover.python.master.watcher.base_watcher import Node
 
 _dlrover_ctx = Context.singleton_instance()
@@ -230,9 +231,9 @@ def __init__(self, master):
             self._min_node = rdzv_manager.get_min_nodes()
         else:
             self._min_node = sys.maxsize
-        self._failed_worker_count = 0
         self._total_worker_num = self._master.job_manager.get_worker_num()
         self._available_worker_num = self._total_worker_num
+        self._job_context = get_job_context()
 
     def get_job_exit_reason(self, node: Node):
         if self._master.task_manager.training_started():
@@ -271,7 +272,7 @@ def on_node_succeeded(self, node: Node, cluster_context: ClusterContext):
     @NodeEventCallback.log_callback_exception
     def on_node_failed(self, node: Node, cluster_context):
         node.finish_time = datetime.now()  # type: ignore
-        self._failed_worker_count += 1
+        self._job_context.report_failed_node()
         self._stop_job_if_needed(node)
         if node.is_unrecoverable_failure():
             self._master.speed_monitor.reduce_target_worker_num(
@@ -327,7 +328,7 @@ def _stop_job_if_needed(self, node: Node):
                     )
                 ),
             )
-        elif self._failed_worker_count >= max_failure_num:
+        elif self._job_context.get_failed_node_cnt() >= max_failure_num:
             # The job early stops if there are a lot of failed workers.
             self._master.request_stop(
                 success=False,
diff --git a/dlrover/python/master/node/job_context.py b/dlrover/python/master/node/job_context.py
index f35e3bd71..816339b9d 100644
--- a/dlrover/python/master/node/job_context.py
+++ b/dlrover/python/master/node/job_context.py
@@ -13,7 +13,8 @@
 
 import copy
 import threading
-from typing import Dict, Optional
+import time
+from typing import Dict, Optional, Union
 
 from dlrover.python.common.constants import NodeType
 from dlrover.python.common.node import Node
@@ -36,6 +37,7 @@ class JobContext(Singleton):
     def __init__(self):
         self._action_queue = DiagnosisActionQueue()
         self._job_nodes: Dict[str, Dict[int, Node]] = {}
+        self._failed_nodes: Dict[int, int] = {}
         self._locker = threading.Lock()
 
     def enqueue_action(self, action):
@@ -179,6 +181,18 @@ def clear_job_nodes(self):
         with self._locker:
             self._job_nodes = {}
 
+    def report_failed_node(self, node_id: Union[int, str] = None):
+        if node_id is None:
+            return
+
+        node_id = int(node_id)
+        with self._locker:
+            if node_id not in self._failed_nodes:
+                self._failed_nodes[node_id] = int(time.time())
+
+    def get_failed_node_cnt(self):
+        return len(self._failed_nodes)
+
 
 def get_job_context() -> JobContext:
     job_context = JobContext.singleton_instance()
diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py
index 7c9705e04..249360fb7 100644
--- a/dlrover/python/tests/test_job_manager.py
+++ b/dlrover/python/tests/test_job_manager.py
@@ -296,9 +296,18 @@ def test_relaunch_node(self):
         should_relaunch = manager._should_relaunch(node, NODE_STATE_FLOWS[6])
         self.assertFalse(should_relaunch)
 
+        self.assertEqual(manager._job_context.get_failed_node_cnt(), 0)
         manager.handle_training_failure(
             NodeType.WORKER, 0, level=TrainingExceptionLevel.NODE_ERROR
         )
+        manager.handle_training_failure(
+            NodeType.WORKER, 0, level=TrainingExceptionLevel.NODE_ERROR
+        )
+        self.assertEqual(manager._job_context.get_failed_node_cnt(), 1)
+        manager.handle_training_failure(
+            NodeType.WORKER, 1, level=TrainingExceptionLevel.NODE_ERROR
+        )
+        self.assertEqual(manager._job_context.get_failed_node_cnt(), 2)
 
     def test_relaunch_under_deleted_event(self):
         params = MockK8sPSJobArgs()
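
For reference, the failed-node bookkeeping this patch adds to JobContext amounts to a lock-protected, de-duplicated map from node id to first-failure timestamp. The listing below is a minimal, self-contained sketch of that behaviour, not part of the patch: FailedNodeTracker is a hypothetical stand-in class used only for illustration, while report_failed_node()/get_failed_node_cnt() mirror the methods added to JobContext above.

# Minimal sketch of the failed-node tracking introduced in job_context.py.
# "FailedNodeTracker" is a hypothetical stand-in, not the DLRover class.
import threading
import time
from typing import Dict, Optional, Union


class FailedNodeTracker:
    def __init__(self):
        self._locker = threading.Lock()
        # node id -> unix timestamp of the first reported failure
        self._failed_nodes: Dict[int, int] = {}

    def report_failed_node(self, node_id: Optional[Union[int, str]] = None):
        if node_id is None:
            return
        node_id = int(node_id)
        with self._locker:
            # A node that fails repeatedly is only counted once.
            if node_id not in self._failed_nodes:
                self._failed_nodes[node_id] = int(time.time())

    def get_failed_node_cnt(self) -> int:
        return len(self._failed_nodes)


tracker = FailedNodeTracker()
tracker.report_failed_node(0)
tracker.report_failed_node(0)    # duplicate report, counted once
tracker.report_failed_node("1")  # string ids are normalized to int
assert tracker.get_failed_node_cnt() == 2

This de-duplicated count is what _stop_job_if_needed now compares against max_failure_num, and what the updated test_relaunch_node asserts after reporting the same worker twice.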