diff --git a/dlrover/python/common/constants.py b/dlrover/python/common/constants.py index df96bcd40..644b31dd1 100644 --- a/dlrover/python/common/constants.py +++ b/dlrover/python/common/constants.py @@ -358,15 +358,14 @@ class JobConstant(object): INSUFFICIENT_NODE_TIMEOUT_DEFAULT_MIN = 600 INSUFFICIENT_NODE_TIMEOUT_DEFAULT_MAX = 3600 PENDING_NODE_TIMEOUT_DEFAULT_MIN = 600 + NODE_CHECK_TIMEOUT = 300 # grpc timeout 60s MASTER_CLIENT_GRPC_DEFAULT_TIMEOUT = 60 - # master_client.check_straggler timeout - MASTER_CLIENT_CHECK_FAULT_NODE_TIMEOUT = 300 - - # master_client.check_fault_node timeout - MASTER_CLIENT_CHECK_STRAGGLER_NODE_TIMEOUT = 300 + # master_client.check_fault_node/check_straggler timeout value + # must > NODE_CHECK_TIMEOUT + MASTER_CLIENT_CHECK_NODE_TIMEOUT = 360 # sleep 1s on NetworkFailureReason.WAITING_NODE MASTER_CLIENT_CHECK_FAULT_SLEEP_TIMEOUT = 1 diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index adda3d1b6..3033ae26c 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -402,7 +402,7 @@ def check_fault_node(self, timeout=300): break return result.nodes, result.reason - def check_straggler(self, timeout=3): + def check_straggler(self, timeout=300): request = grpc.StragglerExistRequest() start = time.time() while True: diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index f0d8bd740..733bfb044 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -1397,6 +1397,9 @@ def __init__( self._check_round = check_round self._config: ElasticLaunchConfig = config + def _get_check_node_timeout(self): + return JobConstant.MASTER_CLIENT_CHECK_NODE_TIMEOUT + def run(self, role: str = DEFAULT_ROLE) -> bool: spec = self._worker_group.spec role = spec.role @@ -1409,7 +1412,9 @@ def run(self, role: str = DEFAULT_ROLE) -> bool: fault_nodes = [] stragglers = [] for i in range(self._check_round): - result, elapsed_time = self._run_node_check() + result, elapsed_time = self._run_node_check( + timeout=JobConstant.NODE_CHECK_TIMEOUT + ) elapsed_time = round(elapsed_time, 3) logger.info( f"Network check time of round {i} is {elapsed_time}" @@ -1427,10 +1432,10 @@ def run(self, role: str = DEFAULT_ROLE) -> bool: ) success = success or result fault_nodes, fault_reason = self._client.check_fault_node( - timeout=JobConstant.MASTER_CLIENT_CHECK_FAULT_NODE_TIMEOUT + timeout=self._get_check_node_timeout() ) stragglers, straggler_reason = self._client.check_straggler( - timeout=JobConstant.MASTER_CLIENT_CHECK_STRAGGLER_NODE_TIMEOUT + timeout=self._get_check_node_timeout() ) logger.info( f"Fault nodes are: {fault_nodes} with {fault_reason} " diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index 089bcb433..1e960e8d0 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -823,6 +823,21 @@ def test_comm_perf_test(self, mock_run): comm_perf_check(config, entrypoint, args) mock_run.assert_called() + def test_get_check_node_timeout(self): + config = ElasticLaunchConfig(4, 4, 8) + + agent = _create_check_agent( + config=config, + entrypoint="python", + args=[], + rdzv_name="elastic-training", + check_round=2, + ) + self.assertEqual( + agent._get_check_node_timeout(), + JobConstant.MASTER_CLIENT_CHECK_NODE_TIMEOUT, + ) + class MasterRendezvousHandlerTest(unittest.TestCase): def setUp(self) -> None: