Skip to content

Commit

Permalink
Re-implement failed-node count tracking in the job context
Browse files Browse the repository at this point in the history
  • Loading branch information
BalaBalaYi committed Jan 21, 2025
1 parent e889a69 commit 2c5f8e4
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 9 deletions.
6 changes: 5 additions & 1 deletion dlrover/python/elastic_agent/torch/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def next_rendezvous(self):
)
logger.info(msg)
self._join_rendezvous()

start_pending = 0
while True:
self._check_network_rdzv_for_elastic_training()
Expand Down Expand Up @@ -848,7 +849,10 @@ def _assign_worker_ranks(
return workers

def _initialize_workers(self, worker_group):
logger.info("Start initializing training workers.")
logger.info(
"Start initializing "
f"training({self.__class__.__name__}) workers."
)
start_pending = 0
pend_timeout = float(
self._config.rdzv_configs.get("pend_timeout", "inf")
Expand Down
11 changes: 7 additions & 4 deletions dlrover/python/master/node/dist_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,10 +1163,13 @@ def _process_error(
error_data: str,
level: str,
) -> bool:
if self._error_monitor and node is not None:
return self._error_monitor.process_error(
node, restart_count, error_data, level
)
if node:
if level == TrainingExceptionLevel.NODE_ERROR:
self._job_context.report_failed_node(node.id)
if self._error_monitor:
return self._error_monitor.process_error(
node, restart_count, error_data, level
)
return False

def all_running_node_hanged(self):
Expand Down
7 changes: 4 additions & 3 deletions dlrover/python/master/node/event_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
RendezvousManager,
)
from dlrover.python.master.monitor.speed_monitor import SpeedMonitor
from dlrover.python.master.node.job_context import get_job_context
from dlrover.python.master.watcher.base_watcher import Node

_dlrover_ctx = Context.singleton_instance()
Expand Down Expand Up @@ -230,9 +231,9 @@ def __init__(self, master):
self._min_node = rdzv_manager.get_min_nodes()
else:
self._min_node = sys.maxsize
self._failed_worker_count = 0
self._total_worker_num = self._master.job_manager.get_worker_num()
self._available_worker_num = self._total_worker_num
self._job_context = get_job_context()

def get_job_exit_reason(self, node: Node):
if self._master.task_manager.training_started():
Expand Down Expand Up @@ -271,7 +272,7 @@ def on_node_succeeded(self, node: Node, cluster_context: ClusterContext):
@NodeEventCallback.log_callback_exception
def on_node_failed(self, node: Node, cluster_context):
node.finish_time = datetime.now() # type: ignore
self._failed_worker_count += 1
self._job_context.report_failed_node()
self._stop_job_if_needed(node)
if node.is_unrecoverable_failure():
self._master.speed_monitor.reduce_target_worker_num(
Expand Down Expand Up @@ -327,7 +328,7 @@ def _stop_job_if_needed(self, node: Node):
)
),
)
elif self._failed_worker_count >= max_failure_num:
elif self._job_context.get_failed_node_cnt() >= max_failure_num:
# The job early stops if there are a lot of failed workers.
self._master.request_stop(
success=False,
Expand Down
16 changes: 15 additions & 1 deletion dlrover/python/master/node/job_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

import copy
import threading
from typing import Dict, Optional
import time
from typing import Dict, Optional, Union

from dlrover.python.common.constants import NodeType
from dlrover.python.common.node import Node
Expand All @@ -36,6 +37,7 @@ class JobContext(Singleton):
def __init__(self):
self._action_queue = DiagnosisActionQueue()
self._job_nodes: Dict[str, Dict[int, Node]] = {}
self._failed_nodes: Dict[int, int] = {}
self._locker = threading.Lock()

def enqueue_action(self, action):
Expand Down Expand Up @@ -179,6 +181,18 @@ def clear_job_nodes(self):
with self._locker:
self._job_nodes = {}

def report_failed_node(self, node_id: Optional[Union[int, str]] = None):
    """Record that the given node has failed.

    Only the first failure report per node is recorded, together with the
    wall-clock time (epoch seconds) at which it was reported, so repeated
    reports for the same node do not inflate the failure count.

    Args:
        node_id: Identifier of the failed node; a string value is coerced
            to ``int``. ``None`` (the default) is a no-op.
    """
    if node_id is None:
        return

    node_id = int(node_id)
    with self._locker:
        # Keep the first report's timestamp; ignore duplicates.
        if node_id not in self._failed_nodes:
            self._failed_nodes[node_id] = int(time.time())

def get_failed_node_cnt(self):
    """Return how many distinct nodes have been reported as failed."""
    # NOTE(review): read without holding self._locker — `len()` on a dict
    # is a single atomic operation under CPython, so this is safe.
    failed = self._failed_nodes
    return len(failed)


def get_job_context() -> JobContext:
job_context = JobContext.singleton_instance()
Expand Down
9 changes: 9 additions & 0 deletions dlrover/python/tests/test_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,18 @@ def test_relaunch_node(self):
should_relaunch = manager._should_relaunch(node, NODE_STATE_FLOWS[6])
self.assertFalse(should_relaunch)

self.assertEqual(manager._job_context.get_failed_node_cnt(), 0)
manager.handle_training_failure(
NodeType.WORKER, 0, level=TrainingExceptionLevel.NODE_ERROR
)
manager.handle_training_failure(
NodeType.WORKER, 0, level=TrainingExceptionLevel.NODE_ERROR
)
self.assertEqual(manager._job_context.get_failed_node_cnt(), 1)
manager.handle_training_failure(
NodeType.WORKER, 1, level=TrainingExceptionLevel.NODE_ERROR
)
self.assertEqual(manager._job_context.get_failed_node_cnt(), 2)

def test_relaunch_under_deleted_event(self):
params = MockK8sPSJobArgs()
Expand Down

0 comments on commit 2c5f8e4

Please sign in to comment.