diff --git a/dlrover/python/common/constants.py b/dlrover/python/common/constants.py index 5933ac4e7..df96bcd40 100644 --- a/dlrover/python/common/constants.py +++ b/dlrover/python/common/constants.py @@ -362,11 +362,17 @@ class JobConstant(object): # grpc timeout 60s MASTER_CLIENT_GRPC_DEFAULT_TIMEOUT = 60 - # sleep 3s on NetworkFailureReason.WAITING_NODE - MASTER_CLIENT_CHECK_FAULT_TIMEOUT = 1 + # master_client.check_fault_node timeout + MASTER_CLIENT_CHECK_FAULT_NODE_TIMEOUT = 300 - # sleep 3s on NetworkFailureReason.WAITING_NODE - MASTER_CLIENT_CHECK_STRAGGLER_TIMEOUT = 1 + # master_client.check_straggler timeout + MASTER_CLIENT_CHECK_STRAGGLER_NODE_TIMEOUT = 300 + + # sleep 1s on NetworkFailureReason.WAITING_NODE or NO_INIT + MASTER_CLIENT_CHECK_FAULT_SLEEP_TIMEOUT = 1 + + # sleep 1s on NetworkFailureReason.WAITING_NODE + MASTER_CLIENT_CHECK_STRAGGLER_SLEEP_TIMEOUT = 1 # sleep 5s before next node check round NODE_CHECK_NEXT_ROUND_TIMEOUT = 5 diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index 178ef748c..adda3d1b6 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -395,14 +395,14 @@ def check_fault_node(self, timeout=300): result: grpc.NetworkCheckResult = self._get(request) if ( result.reason == NetworkFailureReason.WAITING_NODE - and time.time() - start < timeout - ): - time.sleep(JobConstant.MASTER_CLIENT_CHECK_FAULT_TIMEOUT) + or result.reason == NetworkFailureReason.NO_INIT + ) and time.time() - start < timeout: + time.sleep(JobConstant.MASTER_CLIENT_CHECK_FAULT_SLEEP_TIMEOUT) continue break return result.nodes, result.reason - def check_straggler(self, timeout=300): + def check_straggler(self, timeout=3): request = grpc.StragglerExistRequest() start = time.time() while True: @@ -411,7 +411,9 @@ def check_straggler(self, timeout=300): result.reason == NetworkFailureReason.WAITING_NODE and time.time() - start < timeout ): - 
time.sleep(JobConstant.MASTER_CLIENT_CHECK_STRAGGLER_TIMEOUT) + time.sleep( + JobConstant.MASTER_CLIENT_CHECK_STRAGGLER_SLEEP_TIMEOUT + ) continue break return result.nodes, result.reason diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index 6da0a2c5f..f0d8bd740 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -1426,8 +1426,12 @@ def run(self, role: str = DEFAULT_ROLE) -> bool: elapsed_time, ) success = success or result - fault_nodes, fault_reason = self._client.check_fault_node() - stragglers, straggler_reason = self._client.check_straggler() + fault_nodes, fault_reason = self._client.check_fault_node( + timeout=JobConstant.MASTER_CLIENT_CHECK_FAULT_NODE_TIMEOUT + ) + stragglers, straggler_reason = self._client.check_straggler( + timeout=JobConstant.MASTER_CLIENT_CHECK_STRAGGLER_NODE_TIMEOUT + ) logger.info( f"Fault nodes are: {fault_nodes} with {fault_reason} " f" and stragglers are: {stragglers} with {straggler_reason}" diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index d27735dd3..089bcb433 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -778,6 +778,8 @@ def test_run_agent(self): ) # with no fault and no stragglers + agent._client.check_fault_node = mock.MagicMock(return_value=([], "")) + agent._client.check_straggler = mock.MagicMock(return_value=([], "")) agent._run_node_check = mock.MagicMock(return_value=(True, 100)) agent._stop_workers = mock.MagicMock(return_value=True) self.assertTrue(agent.run()) diff --git a/dlrover/python/tests/test_master_client.py b/dlrover/python/tests/test_master_client.py index e216d8dcd..bbf855c80 100644 --- a/dlrover/python/tests/test_master_client.py +++ b/dlrover/python/tests/test_master_client.py @@ -152,7 +152,7 @@ def test_get(self): 
self.assertEqual(len(nodes), 1) self.assertEqual(nodes[0].type, NodeType.WORKER) - nodes, _ = self._master_client.check_fault_node() + nodes, _ = self._master_client.check_fault_node(timeout=1) self.assertListEqual(nodes, []) round = self._master_client.join_rendezvous(0, 8, "elastic-training")