Refactor failed node counter. (#1450)
* re-implement failed-node counter in job context

* fix network check status report

* fix unit tests

* fix unit tests

* lint

* lint

* fix unit tests
BalaBalaYi authored Jan 22, 2025
1 parent 17ab888 commit 7a95974
Showing 6 changed files with 45 additions and 11 deletions.
8 changes: 6 additions & 2 deletions dlrover/python/elastic_agent/torch/training.py
@@ -337,6 +337,7 @@ def next_rendezvous(self):
)
logger.info(msg)
self._join_rendezvous()

start_pending = 0
while True:
self._check_network_rdzv_for_elastic_training()
@@ -858,7 +859,10 @@ def _assign_worker_ranks(
return workers

def _initialize_workers(self, worker_group):
logger.info("Start initializing training workers.")
logger.info(
"Start initializing "
f"training({self.__class__.__name__}) workers."
)
start_pending = 0
pend_timeout = float(
self._config.rdzv_configs.get("pend_timeout", "inf")
@@ -1489,7 +1493,7 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
logger.warning("This node is a straggler!")
if self._config.exclude_straggler:
raise NodeCheckFailedError(
"The node is a straggler " "and exits."
"The node is a straggler and exits."
)
return success

11 changes: 7 additions & 4 deletions dlrover/python/master/node/dist_job_manager.py
@@ -1163,10 +1163,13 @@ def _process_error(
error_data: str,
level: str,
) -> bool:
if self._error_monitor and node is not None:
return self._error_monitor.process_error(
node, restart_count, error_data, level
)
if node:
if level == TrainingExceptionLevel.NODE_ERROR:
self._job_context.report_failed_node(node.id)
if self._error_monitor:
return self._error_monitor.process_error(
node, restart_count, error_data, level
)
return False

def all_running_node_hanged(self):
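
As a reading aid, the sketch below mirrors the reworked _process_error control flow in isolation: a node-level failure is first recorded in the shared job context and only then handed to the error monitor, if one is configured. All names here (StubJobContext, StubNode, NODE_ERROR) are simplified stand-ins for illustration, not the real dlrover classes.

from dataclasses import dataclass, field
from typing import Optional, Set

NODE_ERROR = "node_error"  # stand-in for TrainingExceptionLevel.NODE_ERROR


@dataclass
class StubJobContext:
    """Stand-in for the shared JobContext used by the job manager."""
    failed_node_ids: Set[int] = field(default_factory=set)

    def report_failed_node(self, node_id):
        self.failed_node_ids.add(int(node_id))


@dataclass
class StubNode:
    id: int


def process_error(job_context, error_monitor, node: Optional[StubNode],
                  restart_count: int, error_data: str, level: str) -> bool:
    """Simplified mirror of the refactored _process_error control flow."""
    if node:
        # Node-level failures are recorded in the shared job context
        # before the (optional) error monitor is consulted.
        if level == NODE_ERROR:
            job_context.report_failed_node(node.id)
        if error_monitor:
            return error_monitor.process_error(
                node, restart_count, error_data, level
            )
    return False


ctx = StubJobContext()
process_error(ctx, None, StubNode(id=3), 0, "OOM", NODE_ERROR)
assert ctx.failed_node_ids == {3}
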
7 changes: 4 additions & 3 deletions dlrover/python/master/node/event_callback.py
@@ -29,6 +29,7 @@
RendezvousManager,
)
from dlrover.python.master.monitor.speed_monitor import SpeedMonitor
from dlrover.python.master.node.job_context import get_job_context
from dlrover.python.master.watcher.base_watcher import Node

_dlrover_ctx = Context.singleton_instance()
@@ -230,9 +231,9 @@ def __init__(self, master):
self._min_node = rdzv_manager.get_min_nodes()
else:
self._min_node = sys.maxsize
self._failed_worker_count = 0
self._total_worker_num = self._master.job_manager.get_worker_num()
self._available_worker_num = self._total_worker_num
self._job_context = get_job_context()

def get_job_exit_reason(self, node: Node):
if self._master.task_manager.training_started():
@@ -271,7 +272,7 @@ def on_node_succeeded(self, node: Node, cluster_context: ClusterContext):
@NodeEventCallback.log_callback_exception
def on_node_failed(self, node: Node, cluster_context):
node.finish_time = datetime.now() # type: ignore
self._failed_worker_count += 1
self._job_context.report_failed_node(node.id)
self._stop_job_if_needed(node)
if node.is_unrecoverable_failure():
self._master.speed_monitor.reduce_target_worker_num(
@@ -327,7 +328,7 @@ def _stop_job_if_needed(self, node: Node):
)
),
)
elif self._failed_worker_count >= max_failure_num:
elif self._job_context.get_failed_node_cnt() >= max_failure_num:
# The job early stops if there are a lot of failed workers.
self._master.request_stop(
success=False,
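
The practical effect of the callback change is that the early-stop threshold is now evaluated against the job-wide failed-node count held in the shared JobContext, rather than a _failed_worker_count local to a single callback instance. A minimal sketch of that behavior follows; the class names and the threshold value are illustrative, not taken from dlrover.

class SharedContext:
    """Stand-in for the singleton JobContext shared across callbacks."""

    def __init__(self):
        self._failed_nodes = {}

    def report_failed_node(self, node_id):
        self._failed_nodes.setdefault(int(node_id), True)

    def get_failed_node_cnt(self):
        return len(self._failed_nodes)


class NodeFailureCallback:
    """Stand-in for the node event callback's failure accounting."""

    def __init__(self, job_context, max_failure_num):
        self._job_context = job_context
        self._max_failure_num = max_failure_num

    def on_node_failed(self, node_id):
        # Record the failure in the shared context instead of a local counter.
        self._job_context.report_failed_node(node_id)

    def should_stop_job(self):
        return self._job_context.get_failed_node_cnt() >= self._max_failure_num


ctx = SharedContext()
cb_a, cb_b = NodeFailureCallback(ctx, 2), NodeFailureCallback(ctx, 2)
cb_a.on_node_failed(0)
cb_b.on_node_failed(1)  # reported via a different callback, same shared context
assert cb_a.should_stop_job() and cb_b.should_stop_job()
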
16 changes: 15 additions & 1 deletion dlrover/python/master/node/job_context.py
@@ -13,7 +13,8 @@

import copy
import threading
from typing import Dict, Optional
import time
from typing import Dict, Optional, Union

from dlrover.python.common.constants import NodeType
from dlrover.python.common.node import Node
@@ -36,6 +37,7 @@ class JobContext(Singleton):
def __init__(self):
self._action_queue = DiagnosisActionQueue()
self._job_nodes: Dict[str, Dict[int, Node]] = {}
self._failed_nodes: Dict[int, int] = {}
self._locker = threading.Lock()

def enqueue_action(self, action):
@@ -179,6 +181,18 @@ def clear_job_nodes(self):
with self._locker:
self._job_nodes = {}

def report_failed_node(self, node_id: Union[int, str] = None):
if node_id is None:
return

node_id = int(node_id)
with self._locker:
if node_id not in self._failed_nodes:
self._failed_nodes[node_id] = int(time.time())

def get_failed_node_cnt(self):
return len(self._failed_nodes)


def get_job_context() -> JobContext:
job_context = JobContext.singleton_instance()
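
Taken in isolation, the new bookkeeping on JobContext is small: report_failed_node records the first failure timestamp per node id, and get_failed_node_cnt returns how many distinct nodes have failed so far. The standalone sketch below (an illustrative stand-in class, not the real dlrover module) shows the resulting semantics, in particular that repeated reports for the same node id are counted only once.

import threading
import time
from typing import Dict, Optional, Union


class FailedNodeTracker:
    """Illustrative stand-in for the failed-node bookkeeping added to JobContext."""

    def __init__(self):
        self._failed_nodes: Dict[int, int] = {}
        self._locker = threading.Lock()

    def report_failed_node(self, node_id: Optional[Union[int, str]] = None):
        if node_id is None:
            return
        node_id = int(node_id)
        with self._locker:
            # Only the first failure timestamp is kept per node id, so
            # duplicate reports do not inflate the counter.
            if node_id not in self._failed_nodes:
                self._failed_nodes[node_id] = int(time.time())

    def get_failed_node_cnt(self) -> int:
        return len(self._failed_nodes)


tracker = FailedNodeTracker()
tracker.report_failed_node(0)
tracker.report_failed_node("0")  # the same node reported again, via a str id
tracker.report_failed_node(1)
assert tracker.get_failed_node_cnt() == 2
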
2 changes: 1 addition & 1 deletion dlrover/python/tests/test_event_callback.py
@@ -72,7 +72,7 @@ def test_on_node_failed(self):
worker.relaunch_count = 1
worker.exit_reason = NodeExitReason.FATAL_ERROR
self.event_cb.on_node_failed(worker, None)
self.assertEqual(self.event_cb._failed_worker_count, 1)
self.assertEqual(self.event_cb._job_context.get_failed_node_cnt(), 1)
self.assertTrue(self.master._stop_requested)
self.master._stop_requested = False
_dlrover_ctx.relaunch_always = True
12 changes: 12 additions & 0 deletions dlrover/python/tests/test_job_manager.py
@@ -240,6 +240,9 @@ def test_relaunch_node(self):
manager = create_job_manager(params, SpeedMonitor())
self.assertEqual(manager._ps_relaunch_max_num, 1)
manager.start()

# reset failed nodes for testing
self.job_context._failed_nodes = {}
self.assertEqual(manager._job_args.job_uuid, _MOCK_JOB_UUID)

job_nodes = self.job_context.job_nodes()
@@ -296,9 +299,18 @@ def test_relaunch_node(self):
should_relaunch = manager._should_relaunch(node, NODE_STATE_FLOWS[6])
self.assertFalse(should_relaunch)

self.assertEqual(self.job_context.get_failed_node_cnt(), 0)
manager.handle_training_failure(
NodeType.WORKER, 0, level=TrainingExceptionLevel.NODE_ERROR
)
manager.handle_training_failure(
NodeType.WORKER, 0, level=TrainingExceptionLevel.NODE_ERROR
)
self.assertEqual(self.job_context.get_failed_node_cnt(), 1)
manager.handle_training_failure(
NodeType.WORKER, 1, level=TrainingExceptionLevel.NODE_ERROR
)
self.assertEqual(self.job_context.get_failed_node_cnt(), 2)

def test_relaunch_under_deleted_event(self):
params = MockK8sPSJobArgs()
