Skip to content

Commit

Permalink
waiting when check_fault_nodes return NO_INIT (#1440)
Browse files Browse the repository at this point in the history
* waiting when check_fault_nodes return NO_INIT

* reduce check_fault_node test timeout

* clean up time consuming stuff in unit test

* restore check_fault_node default timeout

---------

Co-authored-by: Ma Jie Yue <[email protected]>
  • Loading branch information
majieyue and Ma Jie Yue authored Jan 15, 2025
1 parent 7ba180c commit 7257b60
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 12 deletions.
14 changes: 10 additions & 4 deletions dlrover/python/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,17 @@ class JobConstant(object):
# grpc timeout 60s
MASTER_CLIENT_GRPC_DEFAULT_TIMEOUT = 60

# sleep 3s on NetworkFailureReason.WAITING_NODE
MASTER_CLIENT_CHECK_FAULT_TIMEOUT = 1
# master_client.check_fault_node timeout
MASTER_CLIENT_CHECK_FAULT_NODE_TIMEOUT = 300

# sleep 3s on NetworkFailureReason.WAITING_NODE
MASTER_CLIENT_CHECK_STRAGGLER_TIMEOUT = 1
# master_client.check_straggler timeout
MASTER_CLIENT_CHECK_STRAGGLER_NODE_TIMEOUT = 300

# sleep 1s on NetworkFailureReason.WAITING_NODE
MASTER_CLIENT_CHECK_FAULT_SLEEP_TIMEOUT = 1

# sleep 1s on NetworkFailureReason.WAITING_NODE
MASTER_CLIENT_CHECK_STRAGGLER_SLEEP_TIMEOUT = 1

# sleep 5s before next node check round
NODE_CHECK_NEXT_ROUND_TIMEOUT = 5
Expand Down
12 changes: 7 additions & 5 deletions dlrover/python/elastic_agent/master_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,14 +395,14 @@ def check_fault_node(self, timeout=300):
result: grpc.NetworkCheckResult = self._get(request)
if (
result.reason == NetworkFailureReason.WAITING_NODE
and time.time() - start < timeout
):
time.sleep(JobConstant.MASTER_CLIENT_CHECK_FAULT_TIMEOUT)
or result.reason == NetworkFailureReason.NO_INIT
) and time.time() - start < timeout:
time.sleep(JobConstant.MASTER_CLIENT_CHECK_FAULT_SLEEP_TIMEOUT)
continue
break
return result.nodes, result.reason

def check_straggler(self, timeout=300):
def check_straggler(self, timeout=3):
request = grpc.StragglerExistRequest()
start = time.time()
while True:
Expand All @@ -411,7 +411,9 @@ def check_straggler(self, timeout=300):
result.reason == NetworkFailureReason.WAITING_NODE
and time.time() - start < timeout
):
time.sleep(JobConstant.MASTER_CLIENT_CHECK_STRAGGLER_TIMEOUT)
time.sleep(
JobConstant.MASTER_CLIENT_CHECK_STRAGGLER_SLEEP_TIMEOUT
)
continue
break
return result.nodes, result.reason
Expand Down
8 changes: 6 additions & 2 deletions dlrover/python/elastic_agent/torch/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,8 +1426,12 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
elapsed_time,
)
success = success or result
fault_nodes, fault_reason = self._client.check_fault_node()
stragglers, straggler_reason = self._client.check_straggler()
fault_nodes, fault_reason = self._client.check_fault_node(
timeout=JobConstant.MASTER_CLIENT_CHECK_FAULT_NODE_TIMEOUT
)
stragglers, straggler_reason = self._client.check_straggler(
timeout=JobConstant.MASTER_CLIENT_CHECK_STRAGGLER_NODE_TIMEOUT
)
logger.info(
f"Fault nodes are: {fault_nodes} with {fault_reason} "
f" and stragglers are: {stragglers} with {straggler_reason}"
Expand Down
2 changes: 2 additions & 0 deletions dlrover/python/tests/test_elastic_training_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,8 @@ def test_run_agent(self):
)

# with no fault and no stragglers
agent._client.check_fault_node = mock.MagicMock(return_value=([], ""))
agent._client.check_straggler = mock.MagicMock(return_value=([], ""))
agent._run_node_check = mock.MagicMock(return_value=(True, 100))
agent._stop_workers = mock.MagicMock(return_value=True)
self.assertTrue(agent.run())
Expand Down
2 changes: 1 addition & 1 deletion dlrover/python/tests/test_master_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_get(self):
self.assertEqual(len(nodes), 1)
self.assertEqual(nodes[0].type, NodeType.WORKER)

nodes, _ = self._master_client.check_fault_node()
nodes, _ = self._master_client.check_fault_node(timeout=1)
self.assertListEqual(nodes, [])

round = self._master_client.join_rendezvous(0, 8, "elastic-training")
Expand Down

0 comments on commit 7257b60

Please sign in to comment.