Refactor failed node counter. #1450

Merged
8 changes: 6 additions & 2 deletions dlrover/python/elastic_agent/torch/training.py
@@ -337,6 +337,7 @@ def next_rendezvous(self):
         )
         logger.info(msg)
         self._join_rendezvous()
+
         start_pending = 0
         while True:
             self._check_network_rdzv_for_elastic_training()
@@ -858,7 +859,10 @@ def _assign_worker_ranks(
         return workers

     def _initialize_workers(self, worker_group):
-        logger.info("Start initializing training workers.")
+        logger.info(
+            "Start initializing "
+            f"training({self.__class__.__name__}) workers."
+        )
         start_pending = 0
         pend_timeout = float(
             self._config.rdzv_configs.get("pend_timeout", "inf")
@@ -1489,7 +1493,7 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
                 logger.warning("This node is a straggler!")
                 if self._config.exclude_straggler:
                     raise NodeCheckFailedError(
-                        "The node is a straggler " "and exits."
+                        "The node is a straggler and exits."
                     )
         return success

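One note on the straggler-message hunk above: Python concatenates adjacent string literals at compile time, so the split literal and the merged one build the exact same message; the change only tidies the source. A one-line check:

# Adjacent string literals are concatenated at compile time, so both spellings are equal.
assert "The node is a straggler " "and exits." == "The node is a straggler and exits."
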
11 changes: 7 additions & 4 deletions dlrover/python/master/node/dist_job_manager.py
@@ -1163,10 +1163,13 @@ def _process_error(
         error_data: str,
         level: str,
     ) -> bool:
-        if self._error_monitor and node is not None:
-            return self._error_monitor.process_error(
-                node, restart_count, error_data, level
-            )
+        if node:
+            if level == TrainingExceptionLevel.NODE_ERROR:
+                self._job_context.report_failed_node(node.id)
+            if self._error_monitor:
+                return self._error_monitor.process_error(
+                    node, restart_count, error_data, level
+                )
         return False

     def all_running_node_hanged(self):
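A reading of the `_process_error` change above (an interpretation of the diff, not text from the PR): the failed node is now recorded in the shared job context before the error monitor is consulted, so the counter advances even when no error monitor is configured. A minimal, self-contained sketch with stand-in classes (not dlrover's real ones):

NODE_ERROR = "node_error"


class _Ctx:
    """Stand-in for the shared job context."""

    def __init__(self):
        self.failed = set()

    def report_failed_node(self, node_id):
        self.failed.add(int(node_id))


class _Node:
    def __init__(self, node_id):
        self.id = node_id


def process_error(job_context, error_monitor, node, restart_count, error_data, level):
    if node:
        if level == NODE_ERROR:
            # Recorded first, independently of whether a monitor exists.
            job_context.report_failed_node(node.id)
        if error_monitor:
            return error_monitor.process_error(node, restart_count, error_data, level)
    return False


ctx = _Ctx()
process_error(ctx, None, _Node(3), restart_count=0, error_data="OOM", level=NODE_ERROR)
assert ctx.failed == {3}  # counted even though error_monitor is None
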
7 changes: 4 additions & 3 deletions dlrover/python/master/node/event_callback.py
@@ -29,6 +29,7 @@
     RendezvousManager,
 )
 from dlrover.python.master.monitor.speed_monitor import SpeedMonitor
+from dlrover.python.master.node.job_context import get_job_context
 from dlrover.python.master.watcher.base_watcher import Node

 _dlrover_ctx = Context.singleton_instance()
@@ -230,9 +231,9 @@ def __init__(self, master):
             self._min_node = rdzv_manager.get_min_nodes()
         else:
             self._min_node = sys.maxsize
-        self._failed_worker_count = 0
         self._total_worker_num = self._master.job_manager.get_worker_num()
         self._available_worker_num = self._total_worker_num
+        self._job_context = get_job_context()

     def get_job_exit_reason(self, node: Node):
         if self._master.task_manager.training_started():
@@ -271,7 +272,7 @@ def on_node_succeeded(self, node: Node, cluster_context: ClusterContext):
     @NodeEventCallback.log_callback_exception
     def on_node_failed(self, node: Node, cluster_context):
         node.finish_time = datetime.now()  # type: ignore
-        self._failed_worker_count += 1
+        self._job_context.report_failed_node(node.id)
         self._stop_job_if_needed(node)
         if node.is_unrecoverable_failure():
             self._master.speed_monitor.reduce_target_worker_num(
@@ -327,7 +328,7 @@ def _stop_job_if_needed(self, node: Node):
                     )
                 ),
             )
-        elif self._failed_worker_count >= max_failure_num:
+        elif self._job_context.get_failed_node_cnt() >= max_failure_num:
             # The job early stops if there are a lot of failed workers.
             self._master.request_stop(
                 success=False,
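The motivation visible in these hunks is that the pod-event path (`on_node_failed`) and the training-error path (`_process_error`) now feed one shared counter, and `_stop_job_if_needed` reads that same counter against `max_failure_num`. A rough, self-contained sketch of the gate, with the stop request reduced to a print and stand-in classes throughout:

class _SharedContext:
    """Stand-in for the shared JobContext counter."""

    def __init__(self):
        self._failed_nodes = {}

    def report_failed_node(self, node_id):
        self._failed_nodes.setdefault(int(node_id), True)

    def get_failed_node_cnt(self):
        return len(self._failed_nodes)


job_context = _SharedContext()
max_failure_num = 2

job_context.report_failed_node(0)  # e.g. reported by the node event callback
job_context.report_failed_node(1)  # e.g. reported by the job manager's error path

if job_context.get_failed_node_cnt() >= max_failure_num:
    # In the real code this branch calls self._master.request_stop(success=False, ...).
    print("The job early stops because too many workers have failed.")
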
16 changes: 15 additions & 1 deletion dlrover/python/master/node/job_context.py
@@ -13,7 +13,8 @@

 import copy
 import threading
-from typing import Dict, Optional
+import time
+from typing import Dict, Optional, Union

 from dlrover.python.common.constants import NodeType
 from dlrover.python.common.node import Node
@@ -36,6 +37,7 @@
     def __init__(self):
         self._action_queue = DiagnosisActionQueue()
         self._job_nodes: Dict[str, Dict[int, Node]] = {}
+        self._failed_nodes: Dict[int, int] = {}
         self._locker = threading.Lock()

     def enqueue_action(self, action):
@@ -179,6 +181,18 @@
         with self._locker:
             self._job_nodes = {}

+    def report_failed_node(self, node_id: Union[int, str] = None):
+        if node_id is None:
+            return
+
+        node_id = int(node_id)
+        with self._locker:
+            if node_id not in self._failed_nodes:
+                self._failed_nodes[node_id] = int(time.time())
+
+    def get_failed_node_cnt(self):
+        return len(self._failed_nodes)
+

 def get_job_context() -> JobContext:
     job_context = JobContext.singleton_instance()

Codecov annotation on this hunk: added line #L186 was not covered by tests.
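Pulling the new `JobContext` pieces together: failed nodes are keyed by integer id and stamped with the first failure time, so repeated reports for the same node count once, and the lock keeps concurrent reports safe. A standalone sketch that mirrors the added fields and methods (an illustration, not the real `JobContext` class):

import threading
import time
from typing import Dict, Optional, Union


class FailedNodeCounter:
    """Illustrative stand-in for the failed-node bookkeeping added to JobContext."""

    def __init__(self):
        # node id -> timestamp of the first reported failure
        self._failed_nodes: Dict[int, int] = {}
        self._locker = threading.Lock()

    def report_failed_node(self, node_id: Optional[Union[int, str]] = None):
        if node_id is None:
            return
        node_id = int(node_id)
        with self._locker:
            if node_id not in self._failed_nodes:
                self._failed_nodes[node_id] = int(time.time())

    def get_failed_node_cnt(self) -> int:
        return len(self._failed_nodes)


counter = FailedNodeCounter()
counter.report_failed_node(0)
counter.report_failed_node("0")  # the same node reported again, as a string id
counter.report_failed_node(1)
assert counter.get_failed_node_cnt() == 2  # node 0 is only counted once
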
2 changes: 1 addition & 1 deletion dlrover/python/tests/test_event_callback.py
@@ -72,7 +72,7 @@ def test_on_node_failed(self):
         worker.relaunch_count = 1
         worker.exit_reason = NodeExitReason.FATAL_ERROR
         self.event_cb.on_node_failed(worker, None)
-        self.assertEqual(self.event_cb._failed_worker_count, 1)
+        self.assertEqual(self.event_cb._job_context.get_failed_node_cnt(), 1)
         self.assertTrue(self.master._stop_requested)
         self.master._stop_requested = False
         _dlrover_ctx.relaunch_always = True
12 changes: 12 additions & 0 deletions dlrover/python/tests/test_job_manager.py
@@ -240,6 +240,9 @@ def test_relaunch_node(self):
         manager = create_job_manager(params, SpeedMonitor())
         self.assertEqual(manager._ps_relaunch_max_num, 1)
         manager.start()
+
+        # reset failed nodes for testing
+        self.job_context._failed_nodes = {}
         self.assertEqual(manager._job_args.job_uuid, _MOCK_JOB_UUID)

         job_nodes = self.job_context.job_nodes()
@@ -296,9 +299,18 @@
         should_relaunch = manager._should_relaunch(node, NODE_STATE_FLOWS[6])
         self.assertFalse(should_relaunch)

+        self.assertEqual(self.job_context.get_failed_node_cnt(), 0)
         manager.handle_training_failure(
             NodeType.WORKER, 0, level=TrainingExceptionLevel.NODE_ERROR
         )
+        manager.handle_training_failure(
+            NodeType.WORKER, 0, level=TrainingExceptionLevel.NODE_ERROR
+        )
+        self.assertEqual(self.job_context.get_failed_node_cnt(), 1)
+        manager.handle_training_failure(
+            NodeType.WORKER, 1, level=TrainingExceptionLevel.NODE_ERROR
+        )
+        self.assertEqual(self.job_context.get_failed_node_cnt(), 2)

     def test_relaunch_under_deleted_event(self):
         params = MockK8sPSJobArgs()